Source code for eoscircuits.antcircuits.circuit

# pylint:disable=unsupported-membership-test
# pylint:disable=unsubscriptable-object
# pylint:disable=unsupported-assignment-operation
"""Antenna Circuit

This module supports:

1. Changing the affinity values of each of the odorant-receptor pairs
characterizing the input of the Odorant Transduction Process.
2. Changing parameter values of the Biological Spike Generators (BSGs)
associated with each OSN.
3. Changing the number of OSNs expressing the same Odorant Receptor (OR) type.
"""
import copy
import typing as tp
from dataclasses import dataclass, field
import numpy as np
import networkx as nx
from olftrans.olftrans import estimate_resting_spike_rate, estimate_sigma
from ..basecircuit import Config, Circuit, EOSCircuitException
from . import model as NDModel
from . import NDComponents as ndcomp


class ANTException(EOSCircuitException):
    """Base Antenna Exception

    Raised by the Antenna circuit configuration and circuit classes when
    user-supplied values are inconsistent (e.g. length mismatches or unknown
    receptor names).
    """
@dataclass
class ANTConfig(Config):
    """Configuration for Antenna Circuits

    Binding rates of the Odorant Transduction Processes (OTPs) are derived as
    ``br = dr * aff`` from the affinities :code:`affs` and the dissociation
    rates :code:`drs`, and are stored in :code:`node_params["osn_otps"]["br"]`.
    """

    NO: tp.Union[int, tp.Iterable[int]]
    """Number of OSNs per Receptor Type (a single int applies to every receptor)"""
    affs: tp.Iterable[float]
    """Affinity Values, one per receptor"""
    receptors: tp.Iterable[str] = None
    """Name of receptors of length NR"""
    resting: float = None
    """Resting OSN Spike Rates [Hz]"""
    node_params: dict = field(
        default_factory=lambda: dict(osn_bsgs=dict(sigma=0.0025))
    )
    """Parameters for each neuron type"""
    osns: tp.Iterable[tp.Iterable[str]] = field(repr=False, default=None)
    """Ids of OSNs for each channel

    This is a list of lists, where the outer list corresponds to receptor
    types and each inner list holds the ids of the OSNs expressing that
    receptor.
    """

    def __post_init__(self):
        # Work on a private copy: the derived "dr"/"br" arrays are written into
        # node_params below, and a caller-supplied dict reused for another
        # config would otherwise carry stale rates into it.
        self.node_params = copy.deepcopy(self.node_params)
        for n in self.node_types:
            self.node_params.setdefault(n, dict())
        # float dtype so later in-place affinity updates are not truncated
        self.affs = np.atleast_1d(np.asarray(self.affs, dtype=float))

        # set receptor names
        self.receptors = self.set_or_assert(
            self.receptors, [f"{r}" for r in range(self.NR)], self.NR
        )

        # set number of OSNs per receptor
        if np.isscalar(self.NO):
            self.NO = np.full((self.NR,), self.NO, dtype=int)
        else:
            if len(self.NO) != self.NR:
                raise ANTException(
                    "If `NO` is iterable, it has to have length same as affs."
                )
            # array (not e.g. tuple) so per-receptor entries can be reassigned
            self.NO = np.asarray(self.NO, dtype=int)

        # set osn names
        self.osns = self.set_or_assert(
            self.osns,
            [
                [f"OSN/{_or}/{o}" for o in range(self.NO[r])]
                for r, _or in enumerate(self.receptors)
            ],
            self.NO,
        )

        # dissociation rates: default 10.0, scalars broadcast to all receptors
        dr = self.drs
        if dr is None:
            dr = 10.0
        if np.ndim(dr) == 0:
            dr = np.full((self.NR,), dr)
        else:
            dr = np.asarray(dr)
            if len(dr) != self.NR:
                raise ANTException(
                    "If Dissociation rate (dr) is specified as iterable, "
                    "it needs to have length the same as affs."
                )
        self.drs = dr
        self.node_params["osn_otps"]["br"] = self.drs * self.affs

        if self.resting is None and self.sigma is None:
            raise ANTException("Resting and Sigma cannot both be None")
        if self.resting is not None:
            # resting rate takes precedence over any explicitly given sigma
            self.sigma = estimate_sigma(self.resting)

    def set_or_assert(
        self, var: "Config.Attribute", new_var: "Config.Attribute", N: np.ndarray
    ) -> "Config.Attribute":
        """Set Variable or Check Dimensionality

        Replace :code:`var` with :code:`new_var` if it is None, then check that
        the resulting value has the dimensionality given by :code:`N`.

        Arguments:
            var: old variable value
            new_var: new variable value, used when :code:`var` is None
            N: dimensionality for the variable. An int is compared against
                ``len(var)``; a sequence requires ``var`` to have one entry
                per element of ``N``, with entry ``i`` of length ``N[i]``.

        Returns:
            The (possibly replaced) variable value.

        Raises:
            AssertionError: if the dimensionality does not match.
        """
        if var is None:
            var = new_var
        if hasattr(N, "__len__"):
            assert len(var) == len(N), f"expected {len(N)} entries, got {len(var)}"
            assert all(
                len(v) == n for v, n in zip(var, N)
            ), "entry lengths do not match the expected dimensionality"
        else:
            assert len(var) == N, f"expected length {N}, got {len(var)}"
        return var

    def set_affs(self, new_affs) -> None:
        """Set affinity values and recompute the OTP binding rates.

        Arguments:
            new_affs: new affinity values, one per receptor.

        Raises:
            ANTException: if the number of values differs from :code:`NR`.
        """
        new_affs = np.atleast_1d(np.asarray(new_affs, dtype=float))
        if len(new_affs) != self.NR:
            raise ANTException(
                f"affs values length mismatch, expected {self.NR}, "
                f"got {len(new_affs)}"
            )
        self.affs = new_affs
        # ``brs`` is a read-only property, so write the derived rates directly.
        self.node_params["osn_otps"]["br"] = self.drs * self.affs

    @property
    def node_types(self) -> tp.List[str]:
        """Node types configured through :code:`node_params`"""
        return ["osn_otps", "osn_bsgs"]

    @property
    def osn_otps(self) -> tp.List[tp.List[str]]:
        """OTP node ids, grouped per receptor"""
        return [[f"{name}/OTP" for name in names] for names in self.osns]

    @property
    def osn_bsgs(self) -> tp.List[tp.List[str]]:
        """BSG node ids, grouped per receptor"""
        return [[f"{name}/BSG" for name in names] for names in self.osns]

    @property
    def NR(self) -> int:
        """Number of Receptors"""
        return len(self.affs)

    @property
    def sigma(self) -> tp.Optional[float]:
        """Noisy Connor Stevens model Noise Level (None if not set)"""
        return self.node_params["osn_bsgs"].get("sigma")

    @sigma.setter
    def sigma(self, new_sigma) -> None:
        self.node_params["osn_bsgs"]["sigma"] = new_sigma

    @property
    def brs(self) -> tp.Optional[np.ndarray]:
        """Binding Rates of the OTPs (None if not yet computed)"""
        return self.node_params["osn_otps"].get("br")

    @property
    def drs(self) -> tp.Optional[np.ndarray]:
        """Dissociation Rates of the OTPs (None if not set)"""
        return self.node_params["osn_otps"].get("dr")

    @drs.setter
    def drs(self, new_drs) -> None:
        new_drs = np.atleast_1d(new_drs)
        if len(new_drs) != self.NR:
            raise ANTException(
                f"dr values length mismatch, expected {self.NR}, "
                f"got {len(new_drs)}"
            )
        self.node_params["osn_otps"]["dr"] = new_drs
@dataclass(repr=False)
class ANTCircuit(Circuit):
    """Antenna Circuit"""

    config: ANTConfig
    extra_comps: tp.List["NDComponent"] = field(
        init=False, default_factory=lambda: [ndcomp.NoisyConnorStevens, ndcomp.OTP]
    )

    @classmethod
    def create_graph(cls, cfg) -> nx.MultiDiGraph:
        """Build the circuit graph from a config.

        For every receptor, each OSN contributes an ``OTP`` node connected to a
        ``NoisyConnorStevens`` node by an edge with ``variable="I"``. Only
        entries of :code:`cfg.node_params` without ``__len__`` (scalars) are
        applied to the nodes; OTP ``br``/``dr`` come from :code:`cfg.brs` and
        :code:`cfg.drs` per receptor.
        """
        G = nx.MultiDiGraph()
        # BSG parameters and the extra OTP parameters do not depend on the
        # receptor, so compute them once rather than per loop iteration.
        bsg_params = copy.deepcopy(NDModel.NoisyConnorStevens.params)
        bsg_params.update(
            {
                key: val
                for key, val in cfg.node_params["osn_bsgs"].items()
                if not hasattr(val, "__len__")
            }
        )
        shared_otp_params = {
            key: val
            for key, val in cfg.node_params["osn_otps"].items()
            if key not in ["br", "dr"] and not hasattr(val, "__len__")
        }
        for r, (_otp_ids, _bsg_ids) in enumerate(zip(cfg.osn_otps, cfg.osn_bsgs)):
            otp_params = copy.deepcopy(NDModel.OTP.params)
            otp_params.update({"br": cfg.brs[r], "dr": cfg.drs[r]})
            otp_params.update(shared_otp_params)
            for _o_id, _b_id in zip(_otp_ids, _bsg_ids):
                G.add_node(_o_id, **{"class": "OTP"}, **otp_params)
                G.add_node(_b_id, **{"class": "NoisyConnorStevens"}, **bsg_params)
                G.add_edge(_o_id, _b_id, variable="I")
        return G

    @classmethod
    def create_from_config(cls, cfg) -> "ANTCircuit":
        """Create Instance from Config

        Arguments:
            cfg: Config instance that specifies the configuration of the module

        Returns:
            A new ANTCircuit instance
        """
        return cls(graph=cls.create_graph(cfg), config=cfg)

    def set_affinities(self, value, receptors=None) -> None:
        """Set Affinity values.

        Arguments:
            value: new affinity value(s). A scalar is applied to every selected
                receptor; otherwise one value per selected receptor, in order.
            receptors: receptor name(s) to update; all receptors if None.

        Raises:
            ANTException: if a receptor is unknown or the number of values
                does not match the number of receptors.

        .. note::

            Because binding rates are computed from affinities :code:`config.affs`
            and dissociations rates :code:`config.drs`, change affinities will
            have effect of changing binding rates but not dissociation rates.
            Both the graph and the config binding rates are updated, so the
            change survives a graph rebuild (e.g. :code:`set_NO`).
        """
        all_receptors = list(self.config.receptors)
        if receptors is None:
            receptors = all_receptors
        else:
            receptors = list(np.atleast_1d(receptors))
        missing = [r for r in receptors if r not in all_receptors]
        if missing:
            raise ANTException(f"Receptors {missing} not found in list of receptor names")

        if np.ndim(value) == 0:
            value = np.full((len(receptors),), value, dtype=float)
        value = np.atleast_1d(value)
        if len(value) != len(receptors):
            raise ANTException(
                f"Attempting to set values of length {len(value)} into "
                f"{len(receptors)} receptors"
            )

        # float copy so fractional affinities are not truncated by an int array
        affs = np.array(self.config.affs, dtype=float)
        update_dct = {}
        # value[i] belongs to receptors[i], NOT to the global receptor index
        for new_aff, r in zip(value, receptors):
            r_idx = all_receptors.index(r)
            affs[r_idx] = new_aff
            for n in self.config.osn_otps[r_idx]:
                update_dct[n] = {"br": self.graph.nodes[n]["dr"] * new_aff}
        self.config.affs = affs
        # keep config binding rates in sync so create_graph reproduces them
        self.config.node_params["osn_otps"]["br"] = self.config.drs * affs
        nx.set_node_attributes(self.graph, update_dct)

    def set_bsg_params(self, key: str, value: float) -> None:
        """Set parameter value of all BSG nodes.

        The value is also stored in :code:`config.node_params["osn_bsgs"]`
        so that rebuilding the graph (e.g. via :code:`set_NO`) keeps it.
        """
        self.config.node_params["osn_bsgs"][key] = value
        bsg_ids = [n for names in self.config.osn_bsgs for n in names]
        nx.set_node_attributes(self.graph, {n: {key: value} for n in bsg_ids})

    def set_NO(
        self, NO: tp.Union[int, tp.Iterable[int]], receptor=None, aff_noise_std=0.0
    ) -> None:
        """Change number of OSNs expressing each receptor type

        Arguments:
            NO: number of OSNs. An int (or length-1 iterable) applies to every
                selected receptor; otherwise one count per selected receptor.
            receptor: receptor name(s) to update; all receptors if None.
            aff_noise_std: currently unused; kept for API compatibility.

        Raises:
            ANTException: if a receptor is unknown or the number of counts
                does not match the number of receptors.

        The circuit graph is rebuilt from the updated config.
        """
        all_receptors = list(self.config.receptors)
        if receptor is None:
            receptor = all_receptors
        else:
            receptor = list(np.atleast_1d(receptor))
        if any(r not in all_receptors for r in receptor):
            raise ANTException("Receptors not found in list of names")

        counts = np.atleast_1d(NO)
        if counts.size == 1:
            counts = np.full((len(receptor),), counts[0])
        elif len(counts) != len(receptor):
            raise ANTException(
                f"Attempting to set {len(counts)} OSN counts into "
                f"{len(receptor)} receptors"
            )

        for r, n_osn in zip(receptor, counts):
            r_idx = all_receptors.index(r)
            n_osn = int(n_osn)
            self.config.NO[r_idx] = n_osn
            self.config.osns[r_idx] = [f"OSN/{r}/{n}" for n in range(n_osn)]
        self.graph = self.create_graph(self.config)

    def get_node_ids(
        self,
        node_type: "ANTConfig.node_types",
        receptor: tp.Union[str, tp.Iterable[str]] = None,
    ) -> list:
        """Return node ids of a given type, grouped per receptor.

        Arguments:
            node_type: one of :code:`config.node_types`.
            receptor: receptor name(s); all receptors if None.

        Returns:
            A list with, for each requested receptor, the list of its node ids.

        Raises:
            ANTException: if a receptor or the node type is not recognized.
        """
        if receptor is None:
            receptor = self.config.receptors
        else:
            receptor = np.atleast_1d(receptor)
        for r in receptor:
            if r not in self.config.receptors:
                raise ANTException(f"Receptors {r} not found in list of receptor names")
        if node_type not in self.config.node_types:
            raise ANTException(
                f"node_type {node_type} not recognized, "
                f"must be one of {self.config.node_types}"
            )
        node_ids = getattr(self.config, node_type)
        return [node_ids[list(self.config.receptors).index(r)] for r in receptor]

    @property
    def inputs(self) -> dict:
        """Input OTP Nodes IDs and the Variables"""
        return {"conc": [n for names in self.config.osn_otps for n in names]}

    @property
    def outputs(self) -> dict:
        """Output BSG Nodes IDs and the Variables"""
        bsg_ids = [n for names in self.config.osn_bsgs for n in names]
        return {"V": bsg_ids, "spike_state": bsg_ids}