"""Plotting functions"""
import typing as tp
from warnings import warn
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import ticker
import numpy as np
def plot_multiple(
    data_x: np.ndarray, *args, **kwargs
) -> tp.Tuple[plt.Figure, np.ndarray]:
    """
    Plot multiple data curves against a same x-axis on multiple subplots.

    Arguments:
        data_x (ndarray): the data points on the x-axis, used for every curve.
        *args: one entry per subplot, stacked vertically in the given order.
            An entry is either a numpy array, in which case a single curve is
            plotted in the corresponding subplot, or a sequence following the
            format `(data_1, param_1, data_2, param_2, ...)`, where each
            `data_i` is a numpy array and each `param_i` is a `dict` of the
            parameters for plotting `data_i` against `data_x`. A trailing
            `data_i` without a matching `param_i` is plotted with no per-curve
            parameters. The `param_i` dicts are not modified.

    Keyword Arguments:
        figw (float): the figure width. Defaults to 5.
        figh (float): the height of a single subplot; the figure height is
            `figh` times the number of subplots. Defaults to 2.
        xlabel (str): the label of the x-axis, drawn under the last subplot.
            Defaults to "Time, [s]".

    Any additional keyword arguments act as defaults for every curve; a key
    present in a curve's `param_i` takes precedence. The merged parameters
    propagate into the private plotting method `_plot`, and eventually into
    `Axes.plot`.

    Returns:
        fig: the created figure.
        axes: 1-D array of the subplot axes, one per entry of `args`.

    Raises:
        ValueError: if `args` is empty.
    """

    def _plot(axe, data_x, data_y, **kwargs):
        """
        Plot one curve into a subplot and apply its axis settings.

        Arguments:
            axe (matplotlib.axes.Axes): the axes of the subplot.
            data_x (ndarray): the data points along the x-axis.
            data_y (ndarray): the data points along the y-axis.

        Keyword Arguments:
            xlim (tuple): a tuple-like with two entries of limits of the x-axis.
            ylim (tuple): a tuple-like with two entries of limits of the y-axis.
            spike (bool): if True, `data_y` is treated as a spike sequence:
                the y-limits are forced to [0, 1.2], the y tick labels are
                blanked and the y-label defaults to "Spike Train".
            ylabel (str): the label of the y-axis.
            ds_rate (int): keep only every `ds_rate`-th sample of both
                `data_x` and `data_y` before plotting.

        The additional keyword arguments propagate into `Axes.plot`. For
        example, one could use `label` to add a legend entry for the curve.
        """
        xlim = kwargs.pop("xlim", None)
        ylim = kwargs.pop("ylim", None)
        spike = kwargs.pop("spike", False)
        ylabel = kwargs.pop("ylabel", None)
        ds_rate = kwargs.pop("ds_rate", None)
        if spike:
            ylim = [0, 1.2]
            ylabel = ylabel or "Spike Train"
            axe.yaxis.set_ticklabels([" "])
        if ds_rate is not None:
            data_x = data_x[::ds_rate]
            data_y = data_y[::ds_rate]
        axe.plot(data_x, data_y, **kwargs)
        # Compare against None rather than relying on truthiness so that
        # numpy-array limits do not raise "truth value is ambiguous".
        if xlim is not None:
            axe.set_xlim(xlim)
        if ylim is not None:
            axe.set_ylim(ylim)
        if ylabel:
            axe.set_ylabel(ylabel)

    figw = kwargs.pop("figw", 5)
    figh = kwargs.pop("figh", 2)
    xlabel = kwargs.pop("xlabel", "Time, [s]")
    num = len(args)
    if num == 0:
        raise ValueError("plot_multiple requires at least one dataset in *args")
    fig, axes = plt.subplots(num, 1, figsize=(figw, num * figh))
    # plt.subplots returns a bare Axes (not an array) when num == 1.
    axes = np.atleast_1d(axes)
    for i, (dataset, axe) in enumerate(zip(args, axes)):
        axe.grid()
        # Only the bottom subplot keeps its x tick labels.
        if i < num - 1:
            axe.xaxis.set_ticklabels([])
        if isinstance(dataset, np.ndarray):
            data_list = [dataset]
            param_list = [{}]
        else:
            data_list = list(dataset[0::2])
            param_list = list(dataset[1::2])
            # A trailing data entry without a parameter dict would otherwise
            # be silently dropped by zip(); pair it with empty parameters.
            param_list += [{}] * (len(data_list) - len(param_list))
        has_legend = False
        for data_y, params in zip(data_list, param_list):
            # Build a fresh dict so the caller's `param_i` is left untouched
            # and can be reused across calls; per-curve keys win over kwargs.
            merged = dict(params)
            for key, val in kwargs.items():
                merged.setdefault(key, val)
            has_legend = has_legend or ("label" in merged)
            _plot(axe, data_x, data_y, **merged)
        if has_legend:
            axe.legend()
    axes[-1].set_xlabel(xlabel)
    plt.tight_layout()
    return fig, axes
def plot_spikes(
    spikes: np.ndarray,
    t: np.ndarray = None,
    ax: plt.Axes = None,
    markersize: int = None,
    color: tp.Union[str, tp.Any] = "k",
) -> plt.Axes:
    """
    Plot spikes in raster format.

    Each nonzero entry `spikes[i, j]` is drawn as a vertical tick ("|" marker)
    at x = `t[j]`, y = `i`.

    Arguments:
        spikes: the spike states in binary format, where 1 stands for a spike.
            The shape of the spikes should either be (N_times, ) or
            (N_trials, N_times); a 1-D input is treated as a single row.

    Keyword Arguments:
        t: time points of the columns of `spikes` (any array-like); uses
            `arange` (column indices) if not provided.
        ax: which axes to plot into; a new subplot is added to the current
            figure if not provided.
        markersize: size of the raster ticks.
        color: color for the raster. Any acceptable type of
            `matplotlib.pyplot.plot`'s color argument is accepted.

    Returns:
        ax: the axes that the raster is plotted into.

    Raises:
        ValueError: if `spikes` has more than 2 dimensions, or if plotting
            fails (e.g. bad color/markersize, or `t` too short for `spikes`).
    """
    spikes = np.atleast_2d(spikes)
    if spikes.ndim != 2:
        raise ValueError(
            f"matrix need to be of ndim 2, (channels x time), got ndim={spikes.ndim}"
        )
    neu_idx, t_idx = np.nonzero(spikes)
    # Accept any array-like time vector (e.g. a list): fancy indexing with
    # `t_idx` and `min()`/`max()` below require an ndarray.
    t = np.arange(spikes.shape[1]) if t is None else np.asarray(t)
    if ax is None:
        ax = plt.gcf().add_subplot()
    try:
        ax.plot(t[t_idx], neu_idx, "|", c=color, markersize=markersize)
    except ValueError as e:
        raise ValueError(
            "Raster plot failed, likely an issue with color or markersize setting"
        ) from e
    except IndexError as e:
        raise ValueError(
            "Raster plot failed, likely an issue with spikes and time vector mismatch"
        ) from e
    except Exception as e:
        raise ValueError("Raster plot failed due to unknown error") from e
    # An empty time axis has no min/max; leave the default limits then.
    if t.size > 0:
        ax.set_xlim([t.min(), t.max()])
    return ax
def plot_mat(
    mat: np.ndarray,
    t: np.ndarray = None,
    ax: plt.Axes = None,
    cax=None,
    vmin: float = None,
    vmax: float = None,
    cbar_kw: dict = None,
    cmap: tp.Any = None,
) -> tp.Union[tp.Tuple[plt.Axes, tp.Any], tp.Tuple[plt.Axes]]:
    """Plot a matrix as an image with time-formatted x-axis tick labels.

    Each row of `mat` is drawn as one horizontal band (row 0 at the bottom,
    since `origin="lower"`), and the x tick labels are converted from column
    position to time using `t`.

    Arguments:
        mat: the matrix to be plotted, of shape (N, N_times); a 1-D input of
            shape (N_times,) is treated as a single row.

    Keyword Arguments:
        t: time points of the columns of `mat`; uses `arange` (column indices)
            if not provided. Only `t[0]` and `t[1]` are used, so `t` is
            assumed to be uniformly spaced; tick labels show
            `t[0] + x * (t[1] - t[0])` for column position `x`. With a single
            time point a spacing of 1 is used.
        ax: which axes to plot into; a new subplot is added to the current
            figure if not provided.
        cax: where to draw a colorbar
            - if an instance of `plt.Axes`, draw the colorbar into that axes
            - if any other truthy value (e.g. `True`), steal space from `ax`
            - if falsy (the default), no colorbar is drawn
        vmin: minimum value for the imshow color scale
        vmax: maximum value for the imshow color scale
        cbar_kw: keyword arguments to be passed into the colorbar creation
        cmap: colormap to use

    Returns:
        `(ax, cbar)` if a colorbar is drawn, otherwise the 1-tuple `(ax,)`:
        ax: the axes that the matrix is plotted into
        cbar: the colorbar object

    Raises:
        ValueError: if `mat` has more than 2 dimensions.

    Example:
        >>> dt, dur, start, stop = 1e-4, 2, 0.5, 1.0
        >>> t = np.arange(0, dur, dt)
        >>> amps = np.arange(0, 100, 10)
        >>> wav = utils.generate_stimulus('step', dt, dur, (start, stop), amps)
        >>> ax,cbar = plot_mat(wav, t=t, cax=True, vmin=10, vmax=100, cbar_kw={'label':'test'}, cmap=plt.cm.gnuplot)
        >>> ax, = plot_mat(wav, t=t, cax=False, vmin=10, vmax=100, cbar_kw={'label':'test'}, cmap=plt.cm.gnuplot)
    """
    mat = np.atleast_2d(mat)
    if mat.ndim != 2:
        raise ValueError(
            f"matrix need to be of ndim 1 (N_time),or ndim 2 (N_trials x N_times), got ndim={mat.ndim}"
        )
    t = np.arange(mat.shape[1]) if t is None else np.asarray(t)
    # imshow centers column i at x = i, so tick position x corresponds to
    # time t0 + x * dt. Including the offset t[0] keeps the labels correct
    # for time vectors that do not start at zero; with fewer than two samples
    # there is no spacing to infer, so fall back to one unit per column.
    t0 = t[0] if t.size > 0 else 0
    dt = t[1] - t[0] if t.size > 1 else 1

    @ticker.FuncFormatter
    def major_formatter(x, pos):
        return "{:.1f}".format(t0 + dt * x)

    if ax is None:
        ax = plt.gcf().add_subplot()
    cim = ax.imshow(
        mat,
        aspect="auto",
        interpolation="none",
        origin="lower",
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
    )
    ax.xaxis.set_major_formatter(major_formatter)
    if not cax:
        return (ax,)
    if cbar_kw is None:
        cbar_kw = {}
    # Use the figure that owns `ax` instead of pyplot's current figure, so the
    # colorbar is attached correctly even when `ax` is not in plt.gcf().
    if isinstance(cax, plt.Axes):
        cbar = ax.figure.colorbar(cim, cax=cax, **cbar_kw)
    else:
        cbar = ax.figure.colorbar(cim, ax=ax, **cbar_kw)
    return ax, cbar
def yyaxis(ax: plt.Axes, c: tp.Any = "red") -> plt.Axes:
    """Create a second y-axis with colored spine, ticks and label.

    The new axes is created with `Axes.twinx`, so it shares the x-axis with
    `ax` and draws its own y-axis on the right-hand side.

    Arguments:
        ax: axes to create another y-axis from.

    Keyword Arguments:
        c: color applied to the right spine, the y ticks and tick labels, and
            the y-axis label of the new axes. Any color accepted by
            matplotlib. Defaults to "red".

    Returns:
        ax2: axes that shares x with `ax`.
    """
    ax2 = ax.twinx()
    ax2.spines["right"].set_color(c)
    # `colors` sets both the tick marks and the tick label color.
    ax2.tick_params(axis="y", colors=c)
    ax2.yaxis.label.set_color(c)
    return ax2