Source code for eoscircuits.mbcircuits.circuit

# pylint:disable=unsupported-membership-test
# pylint:disable=unsubscriptable-object
# pylint:disable=unsupported-assignment-operation
# pylint:disable=no-member
"""Mushroom Body Circuit

This module supports:

1. Generating and changing random connectivity patterns between PNs and KCs with
varying degree of fan-in ratio (number of PNs connected to a given KC)
2. Changing the strength of feedback inhibition of the APL neuron
"""
from itertools import product
import copy
import typing as tp
from dataclasses import dataclass, field, asdict
import numpy as np
import networkx as nx
import olftrans as olf
import olftrans.fbl
import olftrans.data
from ..alcircuits.circuit import ALConfig
from ..basecircuit import Circuit, EOSCircuitException

from . import model as NDModel
from . import NDComponents as ndcomp


[docs]class MBException(EOSCircuitException): """Base Mushroom Body Exception"""
[docs]@dataclass class MBConfig(ALConfig): NK: int = None """Number of KCs""" NAPL: int = None """Number of APLs""" NFanIn: int = 6 """Number of PNs (regardless of receptor type) connected to each KC""" # names kcs: tp.Iterable[str] = field(repr=False, default=None) apls: tp.Iterable[str] = field(repr=False, default=None) # routings pn_to_kc: tp.Iterable[tp.Iterable[float]] = field(default=None, repr=False) kc_to_apl: tp.Iterable[float] = field(default=None, repr=False) apl_to_kc: tp.Iterable[float] = field(default=None, repr=False) def __post_init__(self): super().__post_init__() # set kc names self.kcs = self.set_or_assert( self.kcs, [f"KC/{k}" for k in range(self.NK)], self.NK ) self.apls = self.set_or_assert( self.apls, [f"APL/{a}" for a in range(self.NAPL)], self.NAPL ) self.pn_to_kc = self.set_or_assert_edges( self.pn_to_kc, self.default_pn_to_kc(), self.NR ) self.kc_to_apl = self.set_or_assert_edges( self.kc_to_apl, self.default_kc_to_apl(), self.NK ) self.apl_to_kc = self.set_or_assert_edges( self.apl_to_kc, self.default_apl_to_kc(), self.NAPL ) @property def node_types(self): return [ "osn_otps", "osn_bsgs", "osn_alphas", "osn_axts", "prelns", "postelns", "postilns", "pns", "kcs", "kcdends", "apls", ] @property def kcdends(self): return [f"{n}/DEND" for n in self.kcs]
[docs] def default_pn_to_kc(self, fanin: int = None, seed: int = None) -> np.ndarray: """PN to KC Connectivity The indices of PNs connected to KC is sampled uniformly from all the PNs, regardless of receptor types. Arguments: fanin: number of PNs connected to a single KC seed: seed for random number generator """ if fanin is not None: if not np.isscalar(fanin): raise MBException("Only supports integer value for PN-to-KC Fan-In") if fanin <= 0 or fanin > self.NK: raise MBException( f"PN-to-KC Fan-In can only be in range [1, {self.NK}], " f"got {fanin}" ) self.NFanIn = fanin rng = np.random.RandomState(seed) idx = np.arange(np.sum(self.NP)) # index of all PNs from all glomeruli pn_kc_mat = np.zeros((self.NK, np.sum(self.NP))) for i in range(self.NK): rng.shuffle(idx) pn_kc_mat[i][idx[: self.NFanIn]] = 1 pn_indices = [0] + list(np.cumsum(self.NP)) # start-stop of PN in each glom tbl = np.empty(self.NR, dtype=np.ndarray) for r, (start, stop) in enumerate(zip(pn_indices[:-1], pn_indices[1:])): kc_idx, pn_idx = np.nonzero(pn_kc_mat[:, start:stop]) pn_to_kc_idx = np.array(list(zip(pn_idx, kc_idx))) tbl[r] = np.sort(pn_to_kc_idx, axis=0) return tbl
[docs] def default_kc_to_apl(self): """Create Densely Connected KC to APL routing table""" tbl = np.empty(self.NK, dtype=np.ndarray) for k in range(self.NK): conn = product([0], np.arange(self.NAPL)) tbl[k] = np.array(list(conn)) return tbl
[docs] def default_apl_to_kc(self): """Create Densely Connected APL to KC routing table""" tbl = np.empty(self.NAPL, dtype=np.ndarray) for a in range(self.NAPL): conn = product([0], np.arange(self.NK)) tbl[a] = np.array(list(conn)) return tbl
[docs]@dataclass class MBCircuit(Circuit): """Mushroom Body Circuit""" config: MBConfig extra_comps: tp.List["NDComponent"] = field( init=False, default_factory=lambda: NDModel.EXTRA_COMPS )
[docs] @classmethod def create_graph(cls, cfg) -> nx.MultiDiGraph: G = nx.MultiDiGraph() for r, (_otp_ids, _bsg_ids) in enumerate(zip(cfg.osn_otps, cfg.osn_bsgs)): bsg_params = copy.deepcopy(NDModel.NoisyConnorStevens.params) bsg_params.update( { key: val for key, val in cfg.node_params["osn_bsgs"].items() if not hasattr(val, "__len__") } ) otp_params = copy.deepcopy(NDModel.OTP.params) otp_params.update({"br": cfg.brs[r], "dr": cfg.drs[r]}) otp_params.update( { key: val for key, val in cfg.node_params["osn_otps"].items() if key not in ["br", "dr"] and not hasattr(val, "__len__") } ) for _o_id, _b_id in zip(_otp_ids, _bsg_ids): G.add_node(_o_id, **{"class": "OTP"}, **otp_params) G.add_node(_b_id, **{"class": "NoisyConnorStevens"}, **bsg_params) G.add_edge(_o_id, _b_id, variable="I") cls.add_nodes_to_graph(G, cfg, "osn_alphas", "Alpha", NDModel) cls.add_nodes_to_graph(G, cfg, "osn_axts", "OSNAxt2", NDModel) cls.add_nodes_to_graph(G, cfg, "pns", "PN", NDModel) cls.add_nodes_to_graph(G, cfg, "prelns", "PreLN", NDModel) cls.add_nodes_to_graph(G, cfg, "postelns", "PostLN", NDModel) cls.add_nodes_to_graph(G, cfg, "postilns", "PostLN", NDModel) cls.add_nodes_to_graph(G, cfg, "kcs", "KC", NDModel) cls.add_nodes_to_graph(G, cfg, "kcdends", "KCDend", NDModel) cls.add_nodes_to_graph(G, cfg, "apls", "APL", NDModel) # connect nodes G.add_edges_from( zip(sum(cfg.osn_bsgs, []), sum(cfg.osn_alphas, [])), variable="spike_state" ) G.add_edges_from( zip(sum(cfg.osn_alphas, []), sum(cfg.osn_axts, [])), variable="g" ) G.add_edges_from( cfg.as_node_ids(cfg.osn_to_preln, cfg.osn_alphas, cfg.prelns), variable="g" ) G.add_edges_from( cfg.as_node_ids(cfg.osn_to_postiln, cfg.osn_alphas, cfg.postilns), variable="g", ) G.add_edges_from( cfg.as_node_ids(cfg.osn_to_posteln, cfg.osn_alphas, cfg.postelns), variable="g", ) G.add_edges_from( cfg.as_node_ids(cfg.preln_to_axt, cfg.prelns, cfg.osn_axts), variable="r" ) G.add_edges_from( cfg.as_node_ids(cfg.axt_to_pn, cfg.osn_axts, cfg.pns), variable="I" ) G.add_edges_from( cfg.as_node_ids(cfg.postiln_to_pn, cfg.postilns, cfg.pns), variable="I" ) G.add_edges_from( cfg.as_node_ids(cfg.posteln_to_pn, cfg.postelns, cfg.pns), variable="I" ) G.add_edges_from( cfg.as_node_ids(cfg.pn_to_kc, cfg.pns, cfg.kcdends), variable="r" ) G.add_edges_from(zip(cfg.kcdends, cfg.kcs), variable="I") G.add_edges_from( cfg.as_node_ids(cfg.kc_to_apl, cfg.kcs, cfg.apls), variable="r" ) G.add_edges_from( cfg.as_node_ids(cfg.apl_to_kc, cfg.apls, cfg.kcdends), variable="g" ) return G
[docs] @classmethod def create_from_config(cls, cfg) -> "ALCircuit": """Create Instance from Config Arguments: cfg: Config instance that specifies the configuration of the module Returns: A new ANTCircuit instance """ return cls(graph=cls.create_graph(cfg), config=cfg)
[docs] def change_apl_strength(self, N: float) -> None: """Set APL Strength Arguments: N: The larger the N the weaker the inhibition. N should be in the range of [1, inf]. Typical values are around :code:`config.NK` """ N = np.clip(N, 1, np.inf) nx.set_node_attributes( self.graph, {n: {"N": float(N)} for n in self.config.apls} )
[docs] def change_pn_to_kc( self, routing: np.ndarray = None, fanin: int = None, seed: int = None ) -> None: if routing is not None: if routing.shape != self.config.pn_to_kc.shape: raise MBException("PN-to-KC Routing shape mismatch.") self.config.pn_to_kc = routing else: routing = self.config.default_pn_to_kc(fanin=fanin, seed=seed) self.config.pn_to_kc = routing self.graph.remove_edges_from( set(self.graph.in_edges(self.config.kcdends)).intersection( set(self.graph.out_edges(sum(self.config.pns, []))) ) ) self.graph.add_edges_from( self.config.as_node_ids( self.config.pn_to_kc, self.config.pns, self.config.kcdends ), variable="r", )
@property def inputs(self) -> dict: """Output OTP Nodes IDs and the Variables""" return {"conc": sum(self.config.osn_otps, [])} @property def outputs(self) -> dict: """Output BSG Nodes IDs and the Variables""" return {"r": sum(self.config.kcs, [])}