# pylint:disable=unsupported-membership-test
# pylint:disable=unsubscriptable-object
# pylint:disable=unsupported-assignment-operation
"""Antennal Lobe Circuit
This module supports:
1. Changing parameter values of Biological Spike Generators (BSGs) associated with each
of the Local and Projection Neurons,
2. Changing the number and connectivity of Projection Neurons innervating a given AL
Glomerulus,
3. Changing the number and connectivity of Local Neurons in the Predictive Coding and
ON-OFF circuits of the AL
"""
from itertools import product
import copy
import typing as tp
from dataclasses import dataclass, field, asdict
import numpy as np
import networkx as nx
import olftrans as olf
import olftrans.fbl
import olftrans.data
from ..basecircuit import Circuit
from . import model as NDModel
from . import NDComponents as ndcomp
from ..antcircuits.circuit import ANTConfig
class ALException(Exception):
    """Base Antennal Lobe Exception

    Raised by the Antennal Lobe circuit when a node type, receptor name or
    routing table passed by the caller is not recognized.
    """
@dataclass
class ALConfig(ANTConfig):
    """Configuration of the Antennal Lobe (AL) circuit.

    Holds the neuron counts, node names and routing tables of the Projection
    Neurons (PNs) and of the pre-/post-synaptic Local Neurons (LNs).
    ``__post_init__`` fills in every name list and routing table left as
    ``None`` with an all-to-all default.

    A routing table has one entry per receptor channel. Each entry is either
    ``None`` (no connections in that channel) or a 2-column integer array
    whose rows are ``(source index, target index)`` pairs.

    NOTE(review): ``NR``, ``NO``, ``receptors``, ``osns``, ``node_params`` and
    ``set_or_assert`` are used here but not defined in this class; presumably
    they come from :class:`ANTConfig` -- confirm against the parent.
    """

    # numbers
    NP: tp.Union[int, tp.Iterable[int]] = None
    """Number of PNs, organized by Receptor Type"""
    NPreLN: int = None
    """Number of Pre-synaptic Local Neurons"""
    NPosteLN: tp.Union[int, tp.Iterable[int]] = None
    """Number of Post-synaptic Excitatory Local Neurons, organized by Receptor Type"""
    NPostiLN: tp.Union[int, tp.Iterable[int]] = None
    """Number of Post-synaptic Inhibitory Local Neurons, organized by Receptor Type"""

    # names
    prelns: tp.Iterable[str] = field(repr=False, default=None)
    postelns: tp.Iterable[tp.Iterable[str]] = field(repr=False, default=None)
    postilns: tp.Iterable[tp.Iterable[str]] = field(repr=False, default=None)
    pns: tp.Iterable[tp.Iterable[str]] = field(repr=False, default=None)

    # routings
    osn_to_preln: tp.Iterable[tp.Iterable[float]] = field(default=None, repr=False)
    osn_to_postiln: tp.Iterable[tp.Iterable[float]] = field(default=None, repr=False)
    osn_to_posteln: tp.Iterable[tp.Iterable[float]] = field(default=None, repr=False)
    preln_to_axt: tp.Iterable[tp.Iterable[float]] = field(default=None, repr=False)
    axt_to_pn: tp.Iterable[tp.Iterable[float]] = field(default=None, repr=False)
    postiln_to_pn: tp.Iterable[tp.Iterable[float]] = field(default=None, repr=False)
    posteln_to_pn: tp.Iterable[tp.Iterable[float]] = field(default=None, repr=False)

    def __post_init__(self):
        """Set Variable Names and Default Routing Tables"""
        super().__post_init__()
        self.node_params["postelns"]["polarity"] = 1.0
        self.node_params["postilns"]["polarity"] = -1.0

        # set pn names
        self.NP = self._per_receptor(self.NP, "NP")
        self.pns = self.set_or_assert(
            self.pns, self._names_by_receptor("PN", self.NP), self.NP
        )

        # set prelns names (PreLNs are a single flat list, not per receptor)
        self.prelns = self.set_or_assert(
            self.prelns, [f"PreLN/{r}" for r in range(self.NPreLN)], self.NPreLN
        )

        # set posteln names
        self.NPosteLN = self._per_receptor(self.NPosteLN, "NPosteLN")
        self.postelns = self.set_or_assert(
            self.postelns,
            self._names_by_receptor("PostLN/e", self.NPosteLN),
            self.NPosteLN,
        )

        # set postiln names
        self.NPostiLN = self._per_receptor(self.NPostiLN, "NPostiLN")
        self.postilns = self.set_or_assert(
            self.postilns,
            self._names_by_receptor("PostLN/i", self.NPostiLN),
            self.NPostiLN,
        )

        # Every routing table ``<name>`` has a matching ``default_<name>()``
        # builder; user-supplied tables are only checked for length.
        for name in self.routing_tables:
            default_table = getattr(self, f"default_{name}")()
            setattr(
                self,
                name,
                self.set_or_assert_edges(getattr(self, name), default_table, self.NR),
            )

    def _per_receptor(self, count, label):
        """Return ``count`` as one value per receptor.

        A scalar is repeated ``NR`` times; anything else must already have
        length ``NR`` and is returned unchanged.
        """
        if np.isscalar(count):
            return np.full(self.NR, count)
        assert len(count) == self.NR, (
            f"{label} must be a scalar or have one entry per receptor "
            f"({self.NR}), got {len(count)}"
        )
        return count

    def _names_by_receptor(self, prefix, counts) -> tp.List[tp.List[str]]:
        """Build ``{prefix}/{receptor}/{index}`` names, one list per receptor."""
        return [
            [f"{prefix}/{_or}/{p}" for p in range(counts[r])]
            for r, _or in enumerate(self.receptors)
        ]

    @property
    def node_types(self) -> tp.List[str]:
        """List of Recognized Node Types"""
        return [
            "osn_otps",
            "osn_bsgs",
            "osn_alphas",
            "osn_axts",
            "prelns",
            "postelns",
            "postilns",
            "pns",
        ]

    @property
    def routing_tables(self) -> tp.List[str]:
        """List of Recognized Routing Tables"""
        return [
            "osn_to_preln",
            "osn_to_postiln",
            "osn_to_posteln",
            "preln_to_axt",
            "axt_to_pn",
            "postiln_to_pn",
            "posteln_to_pn",
        ]

    @property
    def osn_otps(self):
        """Names of the OTP node of every OSN, grouped like ``osns``."""
        return [[f"{name}/OTP" for name in names] for names in self.osns]

    @property
    def osn_bsgs(self):
        """Names of the BSG node of every OSN, grouped like ``osns``."""
        return [[f"{name}/BSG" for name in names] for names in self.osns]

    @property
    def osn_alphas(self):
        """Names of the ALP node of every OSN, grouped like ``osns``."""
        return [[f"{name}/ALP" for name in names] for names in self.osns]

    @property
    def osn_axts(self):
        """Names of the AXT node of every OSN, grouped like ``osns``."""
        return [[f"{name}/AXT" for name in names] for names in self.osns]

    @staticmethod
    def _names_for_channel(names, r) -> np.ndarray:
        """Return the candidate node names of channel ``r`` as an array.

        ``names`` is either a flat sequence of strings shared by every channel
        (e.g. ``prelns``) or a sequence holding one group of names per channel.
        The layout is detected explicitly instead of by trial and error, so
        genuine indexing errors are not masked and groups of different lengths
        are supported.
        """
        is_flat = all(isinstance(name, str) for name in names)
        return np.asarray(names if is_flat else names[r])

    def as_node_ids(self, table, source, target) -> tp.List[tp.Tuple[str, str]]:
        """Convert Routing Table's indices to node Ids

        Arguments:
            table: routing table, one entry per receptor channel. Entries that
                are ``None`` or empty contribute no connection.
            source: source node ids, either flat or grouped per channel
            target: target node ids, either flat or grouped per channel

        Returns:
            A flattened list of all (source, target) node id pairs
        """
        uids = []
        for r, tab in enumerate(table):
            if tab is None:
                continue
            pairs = np.asarray(tab)
            if pairs.size == 0:
                # an empty table has no columns to slice
                continue
            source_ids = self._names_for_channel(source, r)[pairs[:, 0]]
            target_ids = self._names_for_channel(target, r)[pairs[:, 1]]
            uids.extend(zip(source_ids, target_ids))
        return uids

    def set_or_assert_edges(self, array, new_array, size):
        """Return ``array``, or ``new_array`` when ``array`` is ``None``.

        Arguments:
            array: user-supplied routing table or ``None``
            new_array: default routing table used when ``array`` is ``None``
            size: required number of entries of the returned table

        Returns:
            The selected routing table

        Raises:
            AssertionError: if the selected table does not have ``size`` entries
        """
        chosen = new_array if array is None else array
        assert len(chosen) == size, (
            f"routing table must have {size} entries, got {len(chosen)}"
        )
        return chosen

    @staticmethod
    def _count_at(counts, r):
        """Return the count for channel ``r`` from a scalar or per-channel counts."""
        return counts if np.isscalar(counts) else counts[r]

    def _all_to_all_tables(self, n_sources, n_targets) -> np.ndarray:
        """Build a default routing table connecting every source to every target.

        Arguments:
            n_sources: number of source nodes, scalar or one value per channel
            n_targets: number of target nodes, scalar or one value per channel

        Returns:
            Object array of length ``NR``; entry ``r`` is an integer array of
            shape ``(n_sources[r] * n_targets[r], 2)`` listing all index pairs.
        """
        tbl = np.empty(self.NR, dtype=object)
        for r in range(self.NR):
            pairs = product(
                range(self._count_at(n_sources, r)),
                range(self._count_at(n_targets, r)),
            )
            # reshape keeps the 2-column layout even when there are no pairs
            tbl[r] = np.array(list(pairs), dtype=int).reshape(-1, 2)
        return tbl

    def default_osn_to_preln(self):
        """All-to-all OSN -> PreLN routing within each channel."""
        return self._all_to_all_tables(self.NO, self.NPreLN)

    def default_osn_to_posteln(self):
        """All-to-all OSN -> excitatory PostLN routing within each channel."""
        return self._all_to_all_tables(self.NO, self.NPosteLN)

    def default_osn_to_postiln(self):
        """All-to-all OSN -> inhibitory PostLN routing within each channel."""
        return self._all_to_all_tables(self.NO, self.NPostiLN)

    def default_preln_to_axt(self):
        """All-to-all PreLN -> OSN axon-terminal routing within each channel."""
        return self._all_to_all_tables(self.NPreLN, self.NO)

    def default_postiln_to_pn(self):
        """All-to-all inhibitory PostLN -> PN routing within each channel."""
        return self._all_to_all_tables(self.NPostiLN, self.NP)

    def default_posteln_to_pn(self):
        """All-to-all excitatory PostLN -> PN routing within each channel."""
        return self._all_to_all_tables(self.NPosteLN, self.NP)

    def default_axt_to_pn(self):
        """All-to-all OSN axon-terminal -> PN routing within each channel."""
        return self._all_to_all_tables(self.NO, self.NP)
@dataclass(repr=False)
class ALCircuit(Circuit):
    """Antennal Lobe Circuit

    Wraps a ``networkx.MultiDiGraph`` built from an :class:`ALConfig`. Methods
    that change the number of neurons or the routing return a new circuit
    created from an updated copy of the configuration; ``set_node_params``
    updates the existing graph in place.
    """

    config: ALConfig
    extra_comps: tp.List["NDComponent"] = field(
        init=False, default_factory=lambda: NDModel.EXTRA_COMPS
    )

    @staticmethod
    def _flatten(groups) -> list:
        """Concatenate groups of node ids into one flat list."""
        return [node_id for group in groups for node_id in group]

    @classmethod
    def create_graph(cls, cfg) -> nx.MultiDiGraph:
        """Build the executable graph described by ``cfg``.

        Arguments:
            cfg: Config instance that specifies the configuration of the module

        Returns:
            A graph with one node per OTP/BSG/Alpha/axon-terminal/PN/LN and
            edges labelled with the variable passed along them.
        """
        G = nx.MultiDiGraph()

        # Scalar overrides only: values with a length are skipped here.
        # BSG parameters do not depend on the receptor, so build them once.
        bsg_params = copy.deepcopy(NDModel.NoisyConnorStevens.params)
        bsg_params.update(
            {
                key: val
                for key, val in cfg.node_params["osn_bsgs"].items()
                if not hasattr(val, "__len__")
            }
        )
        otp_overrides = {
            key: val
            for key, val in cfg.node_params["osn_otps"].items()
            if key not in ["br", "dr"] and not hasattr(val, "__len__")
        }

        for r, (_otp_ids, _bsg_ids) in enumerate(zip(cfg.osn_otps, cfg.osn_bsgs)):
            # binding/dissociation rates are receptor specific
            otp_params = copy.deepcopy(NDModel.OTP.params)
            otp_params.update({"br": cfg.brs[r], "dr": cfg.drs[r]})
            otp_params.update(otp_overrides)
            for _o_id, _b_id in zip(_otp_ids, _bsg_ids):
                G.add_node(_o_id, **{"class": "OTP"}, **otp_params)
                G.add_node(_b_id, **{"class": "NoisyConnorStevens"}, **bsg_params)
                G.add_edge(_o_id, _b_id, variable="I")

        cls.add_nodes_to_graph(G, cfg, "osn_alphas", "Alpha", NDModel)
        cls.add_nodes_to_graph(G, cfg, "osn_axts", "OSNAxt2", NDModel)
        cls.add_nodes_to_graph(G, cfg, "pns", "PN", NDModel)
        cls.add_nodes_to_graph(G, cfg, "prelns", "PreLN", NDModel)
        cls.add_nodes_to_graph(G, cfg, "postelns", "PostLN", NDModel)
        cls.add_nodes_to_graph(G, cfg, "postilns", "PostLN", NDModel)

        # connect nodes: one-to-one chain BSG -> Alpha -> axon terminal
        G.add_edges_from(
            zip(cls._flatten(cfg.osn_bsgs), cls._flatten(cfg.osn_alphas)),
            variable="spike_state",
        )
        G.add_edges_from(
            zip(cls._flatten(cfg.osn_alphas), cls._flatten(cfg.osn_axts)),
            variable="g",
        )

        # connections described by routing tables, added in a fixed order
        routed = [
            (cfg.osn_to_preln, cfg.osn_alphas, cfg.prelns, "g"),
            (cfg.osn_to_postiln, cfg.osn_alphas, cfg.postilns, "g"),
            (cfg.osn_to_posteln, cfg.osn_alphas, cfg.postelns, "g"),
            (cfg.preln_to_axt, cfg.prelns, cfg.osn_axts, "r"),
            (cfg.axt_to_pn, cfg.osn_axts, cfg.pns, "I"),
            (cfg.postiln_to_pn, cfg.postilns, cfg.pns, "I"),
            (cfg.posteln_to_pn, cfg.postelns, cfg.pns, "I"),
        ]
        for table, source, target, variable in routed:
            G.add_edges_from(
                cfg.as_node_ids(table, source, target), variable=variable
            )
        return G

    @classmethod
    def create_from_config(cls, cfg) -> "ALCircuit":
        """Create Instance from Config

        Arguments:
            cfg: Config instance that specifies the configuration of the module

        Returns:
            A new ALCircuit instance
        """
        return cls(graph=cls.create_graph(cfg), config=cfg)

    def set_node_params(
        self,
        node_type: "ALConfig.node_types",
        key: str,
        value: float,
        receptor: tp.Union[str, tp.Iterable[str]] = None,
    ):
        """Set Parameter Value of selected Nodes

        Arguments:
            node_type: one of ``ALConfig.node_types``
            key: name of the node attribute to set
            value: new value of the attribute
            receptor: receptor name(s) to restrict the update to;
                all receptors when ``None``

        Raises:
            ALException: if ``node_type`` or ``receptor`` is not recognized
        """
        node_ids = self.get_node_ids(node_type=node_type, receptor=receptor)
        if key == "sigma":
            # the config keeps its own copy of sigma
            self.config.sigma = value
        update_dct = {n: {key: value} for n in self._flatten(node_ids)}
        nx.set_node_attributes(self.graph, update_dct)

    def set_neuron_number(
        self,
        node_type: "ALConfig.node_types",
        number: int,
        receptor: tp.Union[str, tp.Iterable[str]] = None,
    ) -> "ALCircuit":
        """Set Number of Neurons and change Routing Table Appropriately

        The node names and routing tables that depend on the changed count are
        dropped from a copy of the config so that they are regenerated with
        their defaults, then a new circuit is built from that config.

        Arguments:
            node_type: one of ``ALConfig.node_types``
            number: new number of neurons of that type
            receptor: currently unused; the new count applies to all receptors

        Returns:
            A new ALCircuit instance; ``self`` is left unchanged

        Raises:
            ALException: if ``node_type`` is not recognized
        """
        if node_type not in self.config.node_types:
            raise ALException(
                f"Node Type {node_type} not found in graph, "
                f"must be one of {self.config.node_types}"
            )

        # node type -> (count field to overwrite, fields to regenerate)
        plans = {
            "pns": ("NP", ("axt_to_pn", "postiln_to_pn", "posteln_to_pn", "pns")),
            "prelns": ("NPreLN", ("osn_to_preln", "preln_to_axt", "prelns")),
            "postelns": (
                "NPosteLN",
                ("osn_to_posteln", "posteln_to_pn", "postelns"),
            ),
            "postilns": (
                "NPostiLN",
                ("osn_to_postiln", "postiln_to_pn", "postilns"),
            ),
        }
        if "osn" in node_type:
            # NOTE(review): assumes the OSN count is the config field ``NO``
            # (it is read as ``self.NO`` by ALConfig) -- confirm on ANTConfig.
            plan = (
                "NO",
                (
                    "osn_to_preln",
                    "osns",
                    "osn_to_postiln",
                    "osn_to_posteln",
                    "preln_to_axt",
                    "axt_to_pn",
                ),
            )
        else:
            plan = plans.get(node_type)
        if plan is None:
            raise ALException(
                f"Changing the number of {node_type} nodes is not supported"
            )

        count_field, stale_fields = plan
        cfg = asdict(self.config)
        for name in stale_fields:
            cfg.pop(name)
        cfg[count_field] = number
        return self.create_from_config(ALConfig(**cfg))

    def set_routing(
        self,
        table: np.ndarray,
        name: str,
        receptor: tp.Union[str, tp.Iterable[str]] = None,
    ) -> "ALCircuit":
        """Set Routing Table in Antennal Lobe

        Arguments:
            table: one routing entry per selected receptor
            name: one of ``ALConfig.routing_tables``
            receptor: receptor name(s) whose routing is replaced;
                all receptors when ``None``

        Returns:
            A new ALCircuit instance; ``self`` is left unchanged

        Raises:
            ALException: if ``name`` is not a routing table or ``table`` does
                not have one entry per selected receptor
        """
        if name not in self.config.routing_tables:
            raise ALException(
                f"Attempting to set table {name}, "
                f"Must be one of {self.config.routing_tables}"
            )
        if receptor is not None:
            receptor = np.atleast_1d(receptor)
        else:
            receptor = self.config.receptors
        if len(table) != len(receptor):
            raise ALException(
                "Table must be of shape "
                "(len(receptor),) with each entry being the routing "
                "in that particular channel"
            )

        # asdict copies the config, so the current circuit is not modified
        cfg = asdict(self.config)
        all_receptors = list(self.config.receptors)
        for n, r in enumerate(receptor):
            cfg[name][all_receptors.index(r)] = table[n]
        return self.create_from_config(ALConfig(**cfg))

    @property
    def inputs(self) -> dict:
        """Input OTP Nodes IDs and the Variables

        Returns:
            OTPs with input variable `conc`
        """
        return {"conc": self._flatten(self.config.osn_otps)}

    @property
    def outputs(self) -> dict:
        """Output PN Nodes IDs and the Variables

        Returns:
            PNs with output variable `r`
        """
        return {"r": self._flatten(self.config.pns)}

    def get_node_ids(
        self,
        node_type: "ALConfig.node_types",
        receptor: tp.Union[str, tp.Iterable[str]] = None,
    ) -> list:
        """Get the node ids of a node type, grouped by receptor

        Arguments:
            node_type: one of ``ALConfig.node_types``
            receptor: receptor name(s) to select; all receptors when ``None``

        Returns:
            A list with one group of node ids per selected receptor. Node
            types stored as a flat list of names (``prelns``) are not
            organized by receptor and are returned as a single group.

        Raises:
            ALException: if ``node_type`` or ``receptor`` is not recognized
        """
        if receptor is None:
            receptor = self.config.receptors
        else:
            receptor = np.atleast_1d(receptor)
        for r in receptor:
            if r not in self.config.receptors:
                raise ALException(f"Receptors {r} not found in list of receptor names")
        if node_type not in self.config.node_types:
            raise ALException(
                f"node_type {node_type} not recognized, "
                f"must be one of {self.config.node_types}"
            )

        node_ids = getattr(self.config, node_type)
        if all(isinstance(node_id, str) for node_id in node_ids):
            # flat list: indexing it by receptor would pick unrelated nodes
            return [list(node_ids)]
        all_receptors = list(self.config.receptors)
        return [node_ids[all_receptors.index(r)] for r in receptor]