# pylint:disable=no-member
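"""
Odorant transduction process (OTP) component for Neurokernel LPUs.

Defines :class:`OTP`, a CUDA-backed :class:`NDComponent` that converts an
odorant concentration input ("conc") into a transduction current output
("I") by integrating the model's internal states with the
``NK_kernels/OTP.cu`` kernel.
"""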
import os
from collections import OrderedDict
import numpy as np
import pycuda.gpuarray as garray
from pycuda.tools import dtype_to_ctype
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
from neurokernel.LPU.NDComponents.NDComponent import NDComponent
CURR_DIR = os.path.dirname(os.path.realpath(__file__))
class OTP(NDComponent):
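    """Odorant transduction process (OTP) model.

    Reads the odorant concentration ("conc") of each component and updates
    its transduction current ("I"); the dynamics over the internal states
    (v, uh, duh, x1, x2, x3) are integrated on the GPU by the ``OTP``
    kernel in ``NK_kernels/OTP.cu``.
    """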
    accesses = ["conc"]  # input: odorant concentration
    updates = ["I"]      # output: transduction current
    params = [
        "br",
        "dr",
        "gamma",
        "a1",
        "b1",
        "a2",
        "b2",
        "a3",
        "b3",
        "kappa",
        "p",
        "c",
        "Imax",
    ]
    internals = OrderedDict(
        [("v", 0.0), ("uh", 0.0), ("duh", 0.0), ("x1", 0.0), ("x2", 0.0), ("x3", 0.0)]
    )
    def maximum_dt_allowed(self):
        """Upper bound on the integration time step for this component."""
        return 1e-4
    def __init__(
        self,
        params_dict,
        access_buffers,
        dt,
        LPU_id=None,
        debug=False,
        cuda_verbose=False,
    ):
        if cuda_verbose:
            self.compile_options = ["--ptxas-options=-v", "--expt-relaxed-constexpr"]
        else:
            self.compile_options = ["--expt-relaxed-constexpr"]

        self.debug = debug
        self.LPU_id = LPU_id
        self.num_comps = params_dict[self.params[0]].size
        self.dtype = params_dict[self.params[0]].dtype
        self.dt = dt
        self.params_dict = params_dict
        self.access_buffers = access_buffers

        # Per-component state arrays, initialized to the defaults declared
        # in `internals`.
        self.internal_states = {
            c: garray.zeros(self.num_comps, dtype=self.dtype) + self.internals[c]
            for c in self.internals
        }

        # Linear staging arrays into which the input buffers are summed on
        # every step.
        self.inputs = {
            k: garray.empty(self.num_comps, dtype=self.access_buffers[k].dtype)
            for k in self.accesses
        }
        # Map each kernel argument name to its dtype so that the update
        # kernel can be prepared with the matching signature.
        dtypes = {"dt": self.dtype}
        dtypes.update({"input_" + k: self.inputs[k].dtype for k in self.accesses})
        dtypes.update({"param_" + k: self.params_dict[k].dtype for k in self.params})
        dtypes.update(
            {"state_" + k: self.internal_states[k].dtype for k in self.internals}
        )
        dtypes.update(
            {
                "output_" + k: self.dtype if k != "spike_state" else np.int32
                for k in self.updates
            }
        )
        self.update_func = self.get_update_func(dtypes)
    def run_step(self, update_pointers, st=None):
        # Sum the contributions to each accessed variable into a single
        # linear array per variable.
        for k in self.inputs:
            self.sum_in_variable(k, self.inputs[k], st=st)
        # Advance the internal states by dt and write the output current
        # through `update_pointers`; argument order matches the prepared
        # signature: states, then params, then inputs, then outputs.
        self.update_func.prepared_async_call(
            self.update_func.grid,
            self.update_func.block,
            st,
            self.num_comps,
            self.dt,
            *[self.internal_states[k].gpudata for k in self.internals]
            + [self.params_dict[k].gpudata for k in self.params]
            + [self.inputs[k].gpudata for k in self.accesses]
            + [update_pointers[k] for k in self.updates]
        )
    def get_update_template(self, float_type):
        # Only double precision is implemented by the CUDA kernel; warn and
        # fall back to it for any other dtype.
        if float_type not in (np.double, np.float64):
            from warnings import warn

            warn("float_type {} not implemented, defaulting to double".format(float_type))
        with open(
            os.path.join(os.path.dirname(CURR_DIR), "NK_kernels/OTP.cu"), "r"
        ) as f:
            code = f.read()
        return code
    def get_update_func(self, dtypes):
        mod = SourceModule(
            self.get_update_template(self.dtype), options=self.compile_options
        )
        func = mod.get_function("OTP")

        type_dict = {k: dtype_to_ctype(dtypes[k]) for k in dtypes}
        type_dict.update(
            {"fletter": "f" if type_dict["param_" + self.params[0]] == "float" else ""}
        )
        # Signature: num_comps (int), dt (scalar), then one pointer per
        # state, parameter, input, and output array; the two non-pointer
        # entries in `type_dict` ("dt" and "fletter") are excluded.
        func.prepare("i" + np.dtype(self.dtype).char + "P" * (len(type_dict) - 2))
        # One thread per component, capped at 6 blocks per multiprocessor.
        func.block = (256, 1, 1)
        func.grid = (
            min(
                6 * cuda.Context.get_device().MULTIPROCESSOR_COUNT,
                (self.num_comps - 1) // 256 + 1,
            ),
            1,
        )
        return func
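

if __name__ == "__main__":
    # A minimal, GPU-free sketch of the grid sizing used in
    # `get_update_func` above. The device numbers below are hypothetical;
    # `mp_count` would normally come from
    # cuda.Context.get_device().MULTIPROCESSOR_COUNT.
    def launch_config(num_comps, mp_count, block_size=256, blocks_per_mp=6):
        blocks_needed = (num_comps - 1) // block_size + 1
        return (min(blocks_per_mp * mp_count, blocks_needed), 1)

    # 10000 components on a hypothetical 20-SM device fit in 40 blocks,
    # well under the 120-block cap.
    print(launch_config(10000, 20))   # -> (40, 1)
    # A very large population hits the cap, so the kernel is expected to
    # stride over components.
    print(launch_config(10**6, 20))   # -> (120, 1)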