# Source code for olftrans.neurodriver.model

"""NeuroDriver Models and Utilities

Examples:
    1. Compute F-I 
    >>> from olftrans.neurodriver import model
    >>> import numpy as np
    >>> dt = 5e-6
    >>> repeat = 50
    >>> Is = np.linspace(0,150,150)
    >>> _, fs = model.compute_fi(model.NoisyConnorStevens, Is, dt=dt, repeat=repeat, save=True)
    
    2. Compute Resting Spike Rate
    >>> from olftrans.neurodriver import model
    >>> import numpy as np
    >>> dt = 5e-6
    >>> repeat = 50
    >>> sigmas = np.linspace(0,0.005,150)
    >>> _, rest_fs = model.compute_resting(model.NoisyConnorStevens, 'sigma', sigmas/np.sqrt(dt), dt=dt, repeat=repeat, save=True)

    3. Compute Peak and SS Currents of OTP
    >>> from olftrans.neurodriver import model
    >>> import numpy as np
    >>> dt, amplitude = 5e-6, 100.
    >>> br_s = np.linspace(1e-2, 1000., 50)
    >>> dr_s = np.linspace(1e-2, 1000., 50)
    >>> _, _, I_ss, I_peak = model.compute_peak_ss_I(br_s, dr_s, dt=dt, amplitude=amplitude, save=True)
"""
from collections import OrderedDict
from neuroballad.models.element import Element
import numpy as np
import copy
import networkx as nx
import typing as tp
import os
from olftrans import ROOTDIR, DATADIR
from scipy.signal import savgol_filter
from tqdm import tqdm

from . import NDComponents
from ..data import data
from .. import utils


class Model(Element):
    """NeuroBallad Element that also wraps the underlying NDComponent"""

    # NeuroDriver NDComponent implementing this model's dynamics;
    # subclasses override with a class from `NDComponents`.
    _ndcomp = None
class OTP(Model):
    """Odorant Transduction Process"""

    element_class = "neuron"
    # Initial values of the OTP state variables.
    states = OrderedDict(
        [("v", 0.0), ("uh", 0.0), ("duh", 0.0), ("x1", 0.0), ("x2", 0.0), ("x3", 0.0)]
    )
    # Default parameters; `br` (binding rate) and `dr` (dissociation rate)
    # are the two swept by `compute_peak_ss_I`. Remaining constants are
    # model coefficients — presumably from the published OTP model; see
    # the NDComponents.OTP implementation for their roles.
    params = dict(
        br=1.0,
        dr=10.0,
        gamma=0.215,
        b1=0.8,
        a1=45.0,
        a2=146.1,
        b2=117.2,
        a3=2.539,
        b3=0.9096,
        kappa=8841.0,
        p=1.0,
        c=0.06546,
        Imax=62.13,
    )
    _ndcomp = NDComponents.OTP
class NoisyConnorStevens(Model):
    """Noisy Connor-Stevens Neuron Model

    F-I curve is controlled by `sigma` parameter

    Notes:
        `sigma` value should be scaled by `sqrt(dt)` as `sigma/sqrt(dt)` where
        `sigma` is the standard deviation of the Brownian Motion
    """

    # Initial values of gating variables, membrane voltages and the
    # refractory counter (note: "refactory" spelling is required by the
    # underlying NDComponent).
    states = dict(
        n=0.0, m=0.0, h=1.0, a=1.0, b=1.0, v1=-60.0, v2=-60.0, refactory=0.0
    )
    # Conductances (gNa, gK, gL, ga), reversal potentials (ENa, EK, EL, Ea),
    # gating shifts (ms, ns, hs), noise level `sigma` and refractory period.
    params = dict(
        ms=-5.3,
        ns=-4.3,
        hs=-12.0,
        gNa=120.0,
        gK=20.0,
        gL=0.3,
        ga=47.7,
        ENa=55.0,
        EK=-72.0,
        EL=-17.0,
        Ea=-75.0,
        sigma=2.05,
        refperiod=1.0,
    )
    _ndcomp = NDComponents.NoisyConnorStevens
def compute_fi(
    NeuronModel: Model,
    Is: np.ndarray,
    repeat: int = 1,
    input_var: str = "I",
    spike_var: str = "spike_state",
    dur: float = 2.0,
    start: float = 0.5,
    dt: float = 1e-5,
    neuron_params: dict = None,
    save: bool = True,
) -> tp.Tuple[np.ndarray, np.ndarray]:
    """Compute Frequency-Current relationship of Neuron

    Arguments:
        NeuronModel: neuron model class to simulate
        Is: injected current amplitudes to sweep
        repeat: number of identical neurons per current value (noise averaging)
        input_var: variable name of the neuron's input current
        spike_var: variable name of the neuron's spike output
        dur: simulation duration [s]
        start: time at which the current step turns on [s]
        dt: simulation time step [s]
        neuron_params: parameter overrides applied on top of the model defaults
        save: whether to persist the result

    Notes:
        If `save==True`, `olftrans.data.data.olfdata.save` is called.

    Returns:
        Is: 1d array of Currents
        spike_rates: 1d array Spiking Frequencies, dimension matches param_values

    Examples:
        Basic Usage:

        >>> from olftrans.neurodriver.model import NoisyConnorStevens
        >>> Is = np.linspace(0., 150, 100)
        >>> dt = 1e-5
        >>> _, fs = compute_fi(NoisyConnorStevens, Is, repeat=50, dt=dt,
        ...                    neuron_params={'sigma': 0.005/np.sqrt(dt)})

        We can look for the input current value from spike rate

        >>> target_spike_rate = 150.  # [Hz]
        >>> target_I = np.interp(x=target_spike_rate, xp=fs, fp=Is)
    """
    # Heavy imports kept local so importing this module does not require
    # a working neurokernel/GPU installation.
    from neurokernel.LPU.LPU import LPU
    from neurokernel.LPU.InputProcessors.StepInputProcessor import StepInputProcessor
    from neurokernel.LPU.OutputProcessors.OutputRecorder import OutputRecorder

    stop = dur
    t = np.arange(0.0, dur, dt)
    clsname = NeuronModel.__name__
    Is = np.atleast_1d(Is)
    neuron_params = neuron_params or {}
    params = copy.deepcopy(NeuronModel.params)
    params.update(neuron_params)

    # One graph node per (current value, repeat) pair; ids are recomputed
    # below when counting spikes, so the naming scheme must stay in sync.
    G = nx.MultiDiGraph()
    csn_ids = np.empty((len(Is), repeat), dtype=object)
    for n_I, _I in enumerate(Is):
        for r in range(repeat):
            _id = f"{clsname}-I{n_I}-{r}"
            G.add_node(_id, **{"label": _id, "class": clsname}, **params)
            csn_ids[n_I, r] = _id

    fi = StepInputProcessor(
        variable=input_var,
        uids=csn_ids.ravel().astype(str),
        val=np.repeat(Is, repeat),
        start=start,
        stop=stop,
    )
    fo = OutputRecorder([(spike_var, None)])
    lpu = LPU(
        dt,
        "obj",
        G,
        device=0,
        id=f"F-I {clsname}",
        input_processors=[fi],
        output_processors=[fo],
        debug=False,
        manager=False,
        extra_comps=[NeuronModel._ndcomp],
    )
    lpu.run(steps=len(t))

    # Count spikes that fall inside the stimulus window for each neuron.
    Nspikes = np.zeros((len(Is), repeat))
    spikes = fo.get_output(var=spike_var)
    for n_I, _I in enumerate(Is):
        for r in range(repeat):
            _id = f"{clsname}-I{n_I}-{r}"
            Nspikes[n_I, r] = np.sum(
                np.logical_and(
                    spikes[_id]["data"] >= start, spikes[_id]["data"] <= stop
                )
            )
    # Average across repeats, normalize by stimulus duration -> Hz.
    spike_rates = Nspikes.mean(-1) / (stop - start)
    if save:
        data.olfdata.save(
            "FI",
            data=data.DataFI(
                Model=clsname,
                Currents=Is,
                Frequencies=spike_rates,
                InputVar=input_var,
                SpikeVar=spike_var,
                # sigma is stored in its dt-independent form (see
                # NoisyConnorStevens notes on sqrt(dt) scaling).
                Params={
                    k: val if k != "sigma" else val * np.sqrt(dt)
                    for k, val in neuron_params.items()
                },
                Repeats=repeat,
            ),
            metadata=data.DataMetadata(dt=dt, dur=dur, start=start, stop=stop),
        )
    return Is, spike_rates
def compute_peak_ss_I(
    br_s: np.ndarray,
    dr_s: np.ndarray,
    dt: float = 1e-5,
    dur: float = 2.0,
    start: float = 0.5,
    save: bool = True,
    amplitude: float = 100.0,
    steady_state_compute_time=None,
) -> tp.Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compute Peak and Steady-State Current output of OTP Model

    Arguments:
        br_s: binding rates to sweep
        dr_s: dissociation rates to sweep
        dt: simulation time step [s]
        dur: simulation duration [s]
        start: onset time of the concentration step [s]
        save: whether to persist the result
        amplitude: concentration step amplitude
        steady_state_compute_time: time at which the steady-state current is
            read out; defaults to the end of the simulation

    Notes:
        if `save==True`, `olftrans.data.data.olfdata.save` is called.

    Returns:
        br_s: 1d array of binding rate
        dr_s: 1d array of dissociation rate
        I_ss: 2d array of resultant steady-state currents
        I_peak: 2d array of resultant peak currents

    Examples:
        Basic Usage:

        >>> br_s = np.linspace(1e-1, 100., 100)
        >>> dr_s = np.linspace(1e-1, 100., 100)
        >>> _, _, I_ss, I_peak = compute_peak_ss_I(br_s, dr_s, save=True)

        Plotting Steady-State Current against affinity:

        >>> DR, BR = np.meshgrid(dr_s, br_s)
        >>> plt.plot((BR/DR).ravel(), I_ss.ravel())  # steady-state current is only dependent on dissociation

        We can also look for the affinity value from steady-state Current output

        >>> target_I = 50.  # [uA]
        >>> target_aff = np.interp(x=target_I, xp=I_ss.ravel(), fp=(BR/DR).ravel())
    """
    # Heavy imports kept local so importing this module does not require
    # a working neurokernel/GPU installation.
    from neurokernel.LPU.LPU import LPU
    from neurokernel.LPU.InputProcessors.StepInputProcessor import StepInputProcessor
    from neurokernel.LPU.OutputProcessors.OutputRecorder import OutputRecorder

    stop = dur
    if steady_state_compute_time is None:
        steady_state_compute_time = stop
    t = np.arange(0, dur, dt)
    G = nx.MultiDiGraph()
    otp_ids = np.empty((len(br_s), len(dr_s)), dtype=object)
    for n_b, _br in enumerate(br_s):
        # BUGFIX: inner loop previously iterated `br_s`, silently reusing
        # binding rates as dissociation rates (and breaking whenever
        # len(br_s) != len(dr_s)).
        for n_d, _dr in enumerate(dr_s):
            _id = f"OTP-B{n_b}-D{n_d}"
            _params = copy.deepcopy(OTP.params)
            _params.update(dict(br=_br, dr=_dr))
            G.add_node(_id, **{"label": _id, "class": "OTP"}, **_params)
            otp_ids[n_b, n_d] = _id

    fi = StepInputProcessor(
        variable="conc",
        uids=otp_ids.ravel().astype(str),
        val=amplitude,
        start=start,
        stop=stop,
    )
    fo = OutputRecorder([("I", None)])
    lpu = LPU(
        dt,
        "obj",
        G,
        device=0,
        id="OTP Currents",
        input_processors=[fi],
        output_processors=[fo],
        debug=False,
        manager=False,
        extra_comps=[OTP._ndcomp],
    )
    lpu.run(steps=len(t))

    # Steady state: mean of the samples within +/- dt of the readout time;
    # Peak: maximum over the whole simulation.
    _ss = fo.output["I"]["data"][
        np.logical_and(
            t >= steady_state_compute_time - dt, t <= steady_state_compute_time + dt
        )
    ].mean(0)
    _peak = fo.output["I"]["data"].max(0)
    I_ss = np.zeros((len(br_s), len(dr_s)))
    I_peak = np.zeros((len(br_s), len(dr_s)))
    uids = list(fo.output["I"]["uids"])
    for n_b, _br in enumerate(br_s):
        # BUGFIX: same br_s -> dr_s correction as above for the readback loop.
        for n_d, _dr in enumerate(dr_s):
            _id = f"OTP-B{n_b}-D{n_d}"
            idx = uids.index(_id)
            I_ss[n_b, n_d] = _ss[idx]
            I_peak[n_b, n_d] = _peak[idx]
    if save:
        data.olfdata.save(
            "OTP",
            data=data.DataOTP(
                Model="OTP", Amplitude=amplitude, Br=br_s, Dr=dr_s, Peak=I_peak, SS=I_ss
            ),
            metadata=data.DataMetadata(dt=dt, dur=dur, start=start, stop=stop),
        )
    return br_s, dr_s, I_ss, I_peak
def compute_resting(
    NeuronModel: Model,
    param_key: str,
    param_values: np.ndarray,
    neuron_params: dict = None,
    repeat: int = 1,
    input_var: str = "I",
    spike_var: str = "spike_state",
    dur: float = 2.0,
    dt: float = 1e-5,
    save: bool = True,
    smoothen: bool = True,
    savgol_window: int = 15,
    savgol_order: int = 3,
) -> tp.Tuple[np.ndarray, np.ndarray]:
    """Compute Resting Spike Rate of a Neuron as Parameter varies

    Arguments:
        NeuronModel: Model to be used to compute Resting Spike Rate
        param_key: Parameter of the model to sweep
        param_values: values of the parameter to sweep
        neuron_params: other parameters to fix
        repeat: number of times the same parameter value is repeated on neuron models
            This is for noise reduction purposes
        input_var: variable name of the input variable for the neuron model
        spike_var: variable name of the spike variable for the neuron model
        dur: duration of simulation
        dt: time resolution of the simulation
        save: whether to save the output
        smoothen: whether to smooth the spike rates with a Savitzky-Golay filter
        savgol_window: window length of the Savitzky-Golay filter
        savgol_order: polynomial order of the Savitzky-Golay filter

    Notes:
        if `save==True`, `olftrans.data.data.olfdata.save` is called.

    Returns:
        param_values: 1d array of param_values
        spike_rates: 1d array Spiking Frequencies, dimension matches param_values

    Examples:
        Basic Usage:

        >>> from olftrans.neurodriver.model import NoisyConnorStevens
        >>> sigmas = np.linspace(0., 0.005, 100)
        >>> _, fs = compute_resting(NoisyConnorStevens, 'sigma', sigmas, repeat=50)

        We can look for the parameter value from resting spike rate

        >>> target_resting = 8.  # [Hz]
        >>> target_sigma = np.interp(x=target_resting, xp=fs, fp=sigmas)
    """
    # Heavy imports kept local so importing this module does not require
    # a working neurokernel/GPU installation.
    from neurokernel.LPU.LPU import LPU
    from neurokernel.LPU.InputProcessors.StepInputProcessor import StepInputProcessor
    from neurokernel.LPU.OutputProcessors.OutputRecorder import OutputRecorder

    start, stop = 0.0, dur
    t = np.arange(0.0, dur, dt)
    clsname = NeuronModel.__name__
    param_values = np.atleast_1d(param_values)
    neuron_params = neuron_params or {}

    # One graph node per (parameter value, repeat) pair.
    G = nx.MultiDiGraph()
    csn_ids = np.empty((len(param_values), repeat), dtype=object)
    for n_p, val in enumerate(param_values):
        params = copy.deepcopy(NeuronModel.params)
        params.update(neuron_params)
        params.update({param_key: val})
        for r in range(repeat):
            _id = f"{clsname}-P{n_p}-{r}"
            G.add_node(_id, **{"label": _id, "class": clsname}, **params)
            csn_ids[n_p, r] = _id

    # Zero input current: we measure the spontaneous (resting) rate only.
    fi = StepInputProcessor(
        variable=input_var,
        uids=csn_ids.ravel().astype(str),
        val=0.0,
        start=start,
        stop=stop,
    )
    fo = OutputRecorder([(spike_var, None)])
    lpu = LPU(
        dt,
        "obj",
        G,
        device=0,
        id=f"Resting Spike Rate {clsname} - Against {param_key}",
        input_processors=[fi],
        output_processors=[fo],
        debug=False,
        manager=False,
        extra_comps=[NeuronModel._ndcomp],
    )
    lpu.run(steps=len(t))

    Nspikes = np.zeros((len(param_values), repeat))
    # BUGFIX: previously hard-coded "spike_state", ignoring the `spike_var`
    # argument; now consistent with the recorder configuration above.
    spikes = fo.get_output(var=spike_var)
    for n_p, val in enumerate(param_values):
        for r in range(repeat):
            _id = f"{clsname}-P{n_p}-{r}"
            Nspikes[n_p, r] = np.sum(
                np.logical_and(
                    spikes[_id]["data"] >= start, spikes[_id]["data"] <= stop
                )
            )
    spike_rates = Nspikes.mean(-1) / (stop - start)
    if smoothen:
        # Smooth the rate curve across parameter values to suppress
        # repeat-to-repeat noise.
        spike_rates = savgol_filter(spike_rates, savgol_window, savgol_order)
    if save:
        data.olfdata.save(
            "REST",
            data=data.DataRest(
                Model=clsname,
                ParamKey=param_key,
                ParamValue=param_values,
                Smoothen=smoothen,
                Frequencies=spike_rates,
                InputVar=input_var,
                SpikeVar=spike_var,
                # sigma is stored in its dt-independent form (see
                # NoisyConnorStevens notes on sqrt(dt) scaling).
                Params={
                    k: val if k != "sigma" else val * np.sqrt(dt)
                    for k, val in neuron_params.items()
                },
                Repeats=repeat,
            ),
            metadata=data.DataMetadata(
                dt=dt,
                dur=dur,
                start=start,
                stop=stop,
                savgol_window=savgol_window,
                savgol_order=savgol_order,
            ),
        )
    return param_values, spike_rates
def compute_peak_ss_spike_rate(
    br_s: np.ndarray,
    dr_s: np.ndarray,
    repeat: int = 1,
    dt: float = 1e-5,
    dur: float = 2.0,
    start: float = 0.5,
    save: bool = True,
    amplitude: float = 100.0,
    neuron_params: dict = None,
) -> tp.Tuple[
    np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray
]:
    """Compute Peak and Steady-State Spike Rate output of OTP-BSG Cascade

    Arguments:
        br_s: binding rates to sweep
        dr_s: dissociation rates to sweep
        repeat: number of BSG neurons per (br, dr) pair (noise averaging)
        dt: simulation time step [s]
        dur: simulation duration [s]
        start: onset time of the concentration step [s]
        save: whether to persist the result
            NOTE(review): currently unused — no save branch exists in the
            body; confirm whether an `olfdata.save` call is missing.
        amplitude: concentration step amplitude
        neuron_params: parameter overrides for the NoisyConnorStevens neurons

    Returns:
        br_s: 1d array of binding rate
        dr_s: 1d array of dissociation rate
        I_ss: 2d array of steady-state OTP currents
        I_peak: 2d array of peak OTP currents
        fs_ss: 2d array of steady-state spike rates
        fs_peak: 2d array of peak spike rates
    """
    # Heavy imports kept local so importing this module does not require
    # a working neurokernel/GPU installation.
    from neurokernel.LPU.LPU import LPU
    from neurokernel.LPU.InputProcessors.StepInputProcessor import StepInputProcessor
    from neurokernel.LPU.OutputProcessors.OutputRecorder import OutputRecorder

    neuron_params = neuron_params or {}
    stop = dur
    t = np.arange(0, dur, dt)
    G = nx.MultiDiGraph()
    otp_ids = np.empty((len(br_s), len(dr_s)), dtype=object)
    bsg_ids = np.empty((len(br_s), len(dr_s), repeat), dtype=object)
    for n_b, _br in enumerate(br_s):
        # BUGFIX: inner loop previously iterated `br_s`, silently reusing
        # binding rates as dissociation rates (and breaking whenever
        # len(br_s) != len(dr_s)).
        for n_d, _dr in enumerate(dr_s):
            otp_id = f"OTP-B{n_b}-D{n_d}"
            _params = copy.deepcopy(OTP.params)
            _params.update(dict(br=_br, dr=_dr))
            G.add_node(otp_id, **{"label": otp_id, "class": "OTP"}, **_params)
            otp_ids[n_b, n_d] = otp_id
            # Each OTP drives `repeat` independent noisy neurons via its
            # current output.
            for n_r in range(repeat):
                bsg_id = f"BSG-B{n_b}-D{n_d}-R{n_r}"
                _params = copy.deepcopy(NoisyConnorStevens.params)
                _params.update(neuron_params)
                G.add_node(
                    bsg_id,
                    **{"label": bsg_id, "class": "NoisyConnorStevens"},
                    **_params,
                )
                bsg_ids[n_b, n_d, n_r] = bsg_id
                G.add_edge(otp_id, bsg_id, variable="I")

    fi = StepInputProcessor(
        variable="conc",
        uids=otp_ids.ravel().astype(str),
        val=amplitude,
        start=start,
        stop=stop,
    )
    fo = OutputRecorder(
        [("I", None), ("spike_state", None)], sample_interval=int(1e-3 // dt)
    )
    lpu = LPU(
        dt,
        "obj",
        G,
        device=0,
        id="OTP-BSG Peak vs. SS",
        input_processors=[fi],
        output_processors=[fo],
        debug=False,
        manager=False,
        extra_comps=[OTP._ndcomp, NoisyConnorStevens._ndcomp],
    )
    lpu.run(steps=len(t))

    print("Computing Peak and Steady State Currents")
    # Steady state taken as the last recorded sample; peak as the maximum.
    _ss = fo.output["I"]["data"][-1]
    _peak = fo.output["I"]["data"].max(0)
    I_ss = np.zeros((len(br_s), len(dr_s)))
    I_peak = np.zeros((len(br_s), len(dr_s)))
    uids = list(fo.output["I"]["uids"])
    for n_b, _br in enumerate(br_s):
        # BUGFIX: same br_s -> dr_s correction as above for the readback loop.
        for n_d, _dr in enumerate(dr_s):
            _id = f"OTP-B{n_b}-D{n_d}"
            idx = uids.index(_id)
            I_ss[n_b, n_d] = _ss[idx]
            I_peak[n_b, n_d] = _peak[idx]

    print("Computing Peak and Steady State Spike Rates")
    psths = np.empty((len(br_s), len(dr_s)), dtype=np.ndarray)
    pbar = tqdm(total=len(br_s) * len(dr_s), desc="Computing PSTH...")
    uids = list(fo.output["spike_state"]["uids"])
    for n_b, _br in enumerate(br_s):
        for n_d, _dr in enumerate(dr_s):
            pbar.update()
            # Convert recorded spike times to a binary spike train per repeat.
            spike_states = np.zeros((len(t), repeat))
            for n_r in range(repeat):
                bsg_id = f"BSG-B{n_b}-D{n_d}-R{n_r}"
                idx = uids.index(bsg_id)
                mask = fo.output["spike_state"]["data"]["index"] == idx
                _spikes = fo.output["spike_state"]["data"]["time"][mask]
                # Map spike times to time-bin indices (dt/2 shift guards
                # against floating-point boundary effects).
                spike_states[((_spikes - dt / 2) // dt).astype(int), n_r] = 1
            # PSTH with 20 ms window, 15 ms shift (averaged across repeats).
            psth, psth_t = utils.compute_psth(spike_states, dt, 2e-2, 1.5e-2)
            psths[n_b, n_d] = psth
    pbar.close()

    psths_cascade = np.zeros((len(br_s), len(dr_s), len(psth_t)))
    for n_b, _br in enumerate(br_s):
        for n_d, _dr in enumerate(dr_s):
            psths_cascade[n_b, n_d] = psths[n_b, n_d]
    # Steady-state rate: mean of the last 200 ms of the PSTH; peak: maximum.
    fs_ss = psths_cascade[..., psth_t > (psth_t.max() - 0.2)].mean(-1)
    fs_peak = psths_cascade.max(-1)
    return br_s, dr_s, I_ss, I_peak, fs_ss, fs_peak