Source code for Synaptipy.core.analysis.single_spike

# src/Synaptipy/core/analysis/single_spike.py
# -*- coding: utf-8 -*-
"""
Core Protocol Module 2: Single Spike Analysis.

Consolidates: Spike Detection, AP Characterisation (threshold, amplitude,
half-width, rise/decay times, AHP) and Phase Plane (dV/dt vs V) analysis.

All registry wrapper functions return::

    {
        "module_used": "single_spike",
        "metrics": { ... flat result keys ... }
    }
"""

import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from scipy.signal import savgol_filter

from Synaptipy.core.analysis.passive_properties import apply_ljp_correction
from Synaptipy.core.analysis.registry import AnalysisRegistry
from Synaptipy.core.constants import DVDT_ARTIFACT_CEILING_VS, MIN_RISING_PHASE_MS
from Synaptipy.core.results import SingleSpikeResult, SpikeTrainResult

log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Spike Detection
# ---------------------------------------------------------------------------


[docs] def detect_spikes_threshold( # noqa: C901 data: np.ndarray, time: np.ndarray, threshold: float, refractory_samples: int, peak_search_window_samples: int = None, parameters: Dict[str, Any] = None, dvdt_threshold: float = 20.0, # Default: DVDT_THRESHOLD_VS (Bean 2007) ) -> SpikeTrainResult: """ Detect action potentials using a two-stage dV/dt-threshold crossing algorithm. Algorithm --------- 1. **First-derivative computation**: :func:`numpy.gradient` is applied to *data* with sample spacing ``dt = time[1] - time[0]`` (s), yielding dV/dt in mV s⁻¹. 2. **dV/dt crossing detection**: candidate spike onsets are identified as upward crossings of ``dvdt_threshold * 1000`` (mV s⁻¹). Each crossing is the sample where dV/dt transitions from strictly below to at-or-above the threshold. 3. **Refractory period enforcement**: candidate crossings separated by fewer than *refractory_samples* are suppressed, retaining only the first crossing in each refractory interval (greedy forward scan). 4. **Peak localisation**: for each accepted onset, the voltage maximum within the next *peak_search_window_samples* is found. The candidate is accepted as a spike only if ``data[peak_idx] >= threshold`` (mV). Parameters ---------- data : np.ndarray 1-D voltage array (mV). time : np.ndarray 1-D time array aligned with *data* (s). threshold : float Minimum voltage a candidate peak must reach to be accepted as a spike (mV). Guards against sub-threshold dV/dt transients. refractory_samples : int Minimum number of samples between successive accepted spike onsets. Convert from time: ``int(refractory_period_s * sampling_rate_hz)``. peak_search_window_samples : int, optional Number of samples to search forward from each onset crossing for the voltage peak. Defaults to *refractory_samples* when ``None``. parameters : dict, optional Arbitrary parameter dict stored verbatim in the returned :class:`~Synaptipy.core.results.SpikeTrainResult` for provenance. dvdt_threshold : float, optional dV/dt threshold for onset detection (V s⁻¹, default 20.0). Converted internally to mV s⁻¹ by multiplication with 1000. Returns ------- SpikeTrainResult Attributes populated on success: * ``value`` (int) – total spike count. * ``spike_times`` (np.ndarray) – peak times (s). * ``spike_indices`` (np.ndarray) – peak sample indices. * ``mean_frequency`` (float) – mean instantaneous firing rate ``(n_spikes - 1) / (t_last - t_first)`` (Hz); 0.0 for ≤ 1 spike. * ``is_valid`` (bool) – ``False`` when input arrays are malformed. """ if not isinstance(data, np.ndarray) or data.ndim != 1 or data.size < 2: return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message="Invalid data array", parameters=parameters or {} ) if not isinstance(time, np.ndarray) or time.shape != data.shape: return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message="Time and data mismatch", parameters=parameters or {}, ) if not isinstance(threshold, (int, float)): return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message="Threshold must be numeric", parameters=parameters or {}, ) if not isinstance(refractory_samples, int) or refractory_samples < 0: return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message="Invalid refractory period", parameters=parameters or {}, ) try: dt = time[1] - time[0] if len(time) > 1 else 1.0 # Apply 5 kHz low-pass filter from scipy.signal import butter, sosfiltfilt nyq = 0.5 / dt if 5000.0 < nyq: sos = butter(4, 5000.0, btype="low", output="sos", fs=1.0 / dt) data_filtered = sosfiltfilt(sos, data) else: data_filtered = data dvdt = np.gradient(data_filtered, dt) dvdt_thresh_mvs = dvdt_threshold * 1000.0 crossings = np.where((dvdt[:-1] < dvdt_thresh_mvs) & (dvdt[1:] >= dvdt_thresh_mvs))[0] + 1 if crossings.size == 0: return SpikeTrainResult( value=0, unit="spikes", spike_times=np.array([]), spike_indices=np.array([]), parameters=parameters or {}, ) if refractory_samples <= 0: valid_crossing_indices = crossings else: valid_crossings_list = [crossings[0]] last_crossing_idx = crossings[0] for idx in crossings[1:]: if (idx - last_crossing_idx) >= refractory_samples: valid_crossings_list.append(idx) last_crossing_idx = idx valid_crossing_indices = np.array(valid_crossings_list) if valid_crossing_indices.size == 0: return SpikeTrainResult( value=0, unit="spikes", spike_times=np.array([]), spike_indices=np.array([]), parameters=parameters or {}, ) peak_indices_list = [] if peak_search_window_samples is None: peak_search_window_samples = ( refractory_samples if refractory_samples > 0 else int(0.005 / (time[1] - time[0])) ) for crossing_idx in valid_crossing_indices: search_start = crossing_idx search_end = min(crossing_idx + peak_search_window_samples, len(data)) if search_start >= search_end: peak_idx = crossing_idx else: try: relative_peak_idx = np.argmax(data[search_start:search_end]) peak_idx = search_start + relative_peak_idx except ValueError: peak_idx = crossing_idx if data[peak_idx] >= threshold: peak_indices_list.append(peak_idx) peak_indices_arr = np.array(peak_indices_list).astype(int) peak_times_arr = time[peak_indices_arr] mean_freq = 0.0 if len(peak_times_arr) > 1: spike_span = peak_times_arr[-1] - peak_times_arr[0] if spike_span > 0: mean_freq = (len(peak_times_arr) - 1) / spike_span return SpikeTrainResult( value=len(peak_indices_arr), unit="spikes", spike_times=peak_times_arr, spike_indices=peak_indices_arr, mean_frequency=mean_freq, parameters=parameters or {}, ) except (ValueError, TypeError, KeyError, IndexError) as e: log.error(f"Error during spike detection: {e}", exc_info=True) return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message=str(e), parameters=parameters or {} )
# --------------------------------------------------------------------------- # AP Feature Extraction # ---------------------------------------------------------------------------
[docs] def calculate_spike_features( # noqa: C901 data: np.ndarray, time: np.ndarray, spike_indices: np.ndarray, dvdt_threshold: float = 20.0, ahp_window_sec: float = 0.05, onset_lookback: float = 0.01, fahp_window_ms: Tuple[float, float] = (1.0, 5.0), mahp_window_ms: Tuple[float, float] = (10.0, 50.0), ) -> List[Dict[str, Any]]: """ Calculate detailed features for each detected spike (vectorised NumPy). Returns list of SingleSpikeResult instances per spike, containing metrics like: ap_threshold, amplitude, half_width, rise_time_10_90, decay_time_90_10, fahp_depth, mahp_depth, ahp_duration_half, adp_amplitude, max_dvdt, min_dvdt. Methodology aligns with established electrophysiology standards: - AP threshold (onset) is strictly defined as the first point in the pre-spike lookback window where the discrete derivative dV/dt exceeds the specified ``dvdt_threshold`` (default 20 V/s). - ADP and AHP logic strictly follows the exact trough/peak finding methodology to ensure accurate feature extraction. Args: data: 1-D voltage array (mV). time: Corresponding time array (s). spike_indices: Array of sample indices for each spike peak. dvdt_threshold: The dV/dt threshold for AP onset (V/s). Default 20.0. ahp_window_sec: Duration of AHP/ADP search window (s). onset_lookback: Lookback window before each spike peak (s). fahp_window_ms: (start, end) of fast-AHP window after peak (ms). mahp_window_ms: (start, end) of medium-AHP window after peak (ms). Returns: A list of SingleSpikeResult objects. """ if spike_indices is None or spike_indices.size == 0: return [] spike_indices = np.asarray(spike_indices, dtype=int) n_spikes = len(spike_indices) n_data = len(data) if n_data < 2: return [] dt = time[1] - time[0] if dt <= 0: log.warning("Invalid time vector (dt <= 0). Cannot calculate features.") return [] # Apply 9.9 kHz low-pass Bessel filter for clean derivative calculation from scipy.signal import bessel, sosfiltfilt nyq = 0.5 / dt if 9900.0 < nyq: sos = bessel(4, 9900.0, btype="low", output="sos", fs=1.0 / dt) data_filtered = sosfiltfilt(sos, data) else: data_filtered = data dvdt = np.gradient(data_filtered, dt) lookback_samples = int(onset_lookback / dt) post_peak_samples = int(0.01 / dt) ahp_max_samples = int(ahp_window_sec / dt) # --- AP Threshold (onset) via dV/dt threshold crossing --- # The true physiological base of the action potential is standardly defined # as the point where the rising phase crosses a specific dV/dt threshold (e.g. 20 V/s). # The previous maximum curvature (d2vdt2) method was erroneously flagging the middle # of the upstroke due to extreme voltage acceleration. lookback_range = np.arange(-lookback_samples, 0) onset_window_indices = spike_indices[:, None] + lookback_range np.clip(onset_window_indices, 0, n_data - 1, out=onset_window_indices) onset_dvdt_windows = dvdt[onset_window_indices] # We use the explicit dvdt_threshold parameter (converted to mV/s) to find the onset. target_thresh_mvs = dvdt_threshold * 1000.0 # PHASE-PLANE BACKWARD SEARCH: # Instead of scanning backward from the peak (where dV/dt drops to 0), # we must scan backward from the point of MAXIMUM dV/dt to find the true onset crossing. max_dvdt_rel_idx = np.argmax(onset_dvdt_windows, axis=1) # Create a mask that is only valid BEFORE the max dV/dt point col_idxs = np.tile(np.arange(lookback_samples), (n_spikes, 1)) valid_search_mask = col_idxs <= max_dvdt_rel_idx[:, None] below_thresh_mask = (onset_dvdt_windows < target_thresh_mvs) & valid_search_mask has_crossing = np.any(below_thresh_mask, axis=1) # Reverse the mask to search backward from the max dV/dt rev_below = below_thresh_mask[:, ::-1] first_below_rev_idx = np.argmax(rev_below, axis=1) # Convert reversed index back to original window index first_below_idx = lookback_samples - 1 - first_below_rev_idx # The actual onset is the first sample >= threshold, which is one sample # after it falls below threshold. first_crossing_rel_idx = np.minimum(first_below_idx + 1, lookback_samples - 1) # If no crossing is found, fallback to a fixed window before peak. fallback_indices = np.maximum(0, spike_indices - int(0.002 / dt)) found_thresh_indices = onset_window_indices[np.arange(n_spikes), first_crossing_rel_idx] thresh_indices = np.where(has_crossing, found_thresh_indices, fallback_indices) ap_thresholds = data[thresh_indices] # Biological QC on detected thresholds: flag as NaN when the # per-spike peak rising rate exceeds 300 V/s (artifact ceiling) or the # threshold-to-peak rising phase is shorter than 0.1 ms (false detection). onset_max_dvdt = np.max(onset_dvdt_windows, axis=1) rising_phase_s = (spike_indices - thresh_indices) * dt at_edge = first_crossing_rel_idx == 0 artifact_flag = at_edge & ( (onset_max_dvdt > DVDT_ARTIFACT_CEILING_VS * 1000.0) | (rising_phase_s < MIN_RISING_PHASE_MS / 1000.0) ) ap_thresholds = np.where(artifact_flag, np.nan, ap_thresholds) peak_vals = data[spike_indices] amplitudes = peak_vals - ap_thresholds # --- Full waveform window --- full_window_len = lookback_samples + post_peak_samples full_window_range = np.arange(-lookback_samples, post_peak_samples) full_window_indices = spike_indices[:, None] + full_window_range np.clip(full_window_indices, 0, n_data - 1, out=full_window_indices) waveforms = data[full_window_indices] amp_50 = ap_thresholds + 0.5 * amplitudes amp_10 = ap_thresholds + 0.1 * amplitudes amp_90 = ap_thresholds + 0.9 * amplitudes half_widths = np.full(n_spikes, np.nan) rise_times = np.full(n_spikes, np.nan) decay_times = np.full(n_spikes, np.nan) rel_peak = lookback_samples col_indices = np.arange(full_window_len) is_pre_peak = col_indices < rel_peak is_post_peak = col_indices > rel_peak lev_50 = amp_50[:, None] idxs = np.tile(col_indices, (n_spikes, 1)) temp_mask = is_pre_peak & (waveforms <= lev_50) has_pre_50 = np.any(temp_mask, axis=1) masked_idxs_pre = np.where(temp_mask, idxs, -1) idx_rise_50_rel = np.max(masked_idxs_pre, axis=1) temp_mask_post = is_post_peak & (waveforms <= lev_50) has_post_50 = np.any(temp_mask_post, axis=1) masked_idxs_post = np.where(temp_mask_post, idxs, 999999) idx_fall_50_rel = np.min(masked_idxs_post, axis=1) valid_width = has_pre_50 & has_post_50 & (idx_rise_50_rel != -1) & (idx_fall_50_rel != 999999) safe_idx_rise_50 = np.clip(idx_rise_50_rel, 0, full_window_len - 2) safe_idx_fall_50 = np.clip(idx_fall_50_rel, 1, full_window_len - 1) # Linear interpolation for 50% rise y0_rise = waveforms[np.arange(n_spikes), safe_idx_rise_50] y1_rise = waveforms[np.arange(n_spikes), safe_idx_rise_50 + 1] dy_rise = y1_rise - y0_rise dy_rise[dy_rise == 0] = 1e-9 frac_rise = (amp_50 - y0_rise) / dy_rise x_rise = safe_idx_rise_50 + frac_rise # Linear interpolation for 50% fall y0_fall = waveforms[np.arange(n_spikes), safe_idx_fall_50 - 1] y1_fall = waveforms[np.arange(n_spikes), safe_idx_fall_50] dy_fall = y1_fall - y0_fall dy_fall[dy_fall == 0] = -1e-9 frac_fall = (amp_50 - y0_fall) / dy_fall x_fall = (safe_idx_fall_50 - 1) + frac_fall half_widths[valid_width] = (x_fall[valid_width] - x_rise[valid_width]) * dt * 1000.0 lev_10 = amp_10[:, None] lev_90 = amp_90[:, None] mask_10 = is_pre_peak & (waveforms <= lev_10) valid_10 = np.any(mask_10, axis=1) idx_10_rel = np.max(np.where(mask_10, idxs, -1), axis=1) mask_90 = is_pre_peak & (waveforms <= lev_90) valid_90 = np.any(mask_90, axis=1) idx_90_rel = np.max(np.where(mask_90, idxs, -1), axis=1) valid_rise = valid_10 & valid_90 & (idx_90_rel > idx_10_rel) safe_idx_10 = np.clip(idx_10_rel, 0, full_window_len - 2) safe_idx_90 = np.clip(idx_90_rel, 0, full_window_len - 2) # Linear interpolation for 10% y0_10 = waveforms[np.arange(n_spikes), safe_idx_10] y1_10 = waveforms[np.arange(n_spikes), safe_idx_10 + 1] dy_10 = y1_10 - y0_10 dy_10[dy_10 == 0] = 1e-9 frac_10 = (amp_10 - y0_10) / dy_10 x_10 = safe_idx_10 + frac_10 # Linear interpolation for 90% y0_90 = waveforms[np.arange(n_spikes), safe_idx_90] y1_90 = waveforms[np.arange(n_spikes), safe_idx_90 + 1] dy_90 = y1_90 - y0_90 dy_90[dy_90 == 0] = 1e-9 frac_90 = (amp_90 - y0_90) / dy_90 x_90 = safe_idx_90 + frac_90 rise_times[valid_rise] = (x_90[valid_rise] - x_10[valid_rise]) * dt * 1000.0 mask_dec_90 = is_post_peak & (waveforms <= lev_90) valid_dec_90 = np.any(mask_dec_90, axis=1) idx_dec_90_rel = np.min(np.where(mask_dec_90, idxs, 999999), axis=1) mask_dec_10 = is_post_peak & (waveforms <= lev_10) valid_dec_10 = np.any(mask_dec_10, axis=1) idx_dec_10_rel = np.min(np.where(mask_dec_10, idxs, 999999), axis=1) valid_decay = valid_dec_90 & valid_dec_10 & (idx_dec_10_rel > idx_dec_90_rel) safe_idx_dec_10 = np.clip(idx_dec_10_rel, 1, full_window_len - 1) safe_idx_dec_90 = np.clip(idx_dec_90_rel, 1, full_window_len - 1) # Linear interpolation for decay 90% y0_d90 = waveforms[np.arange(n_spikes), safe_idx_dec_90 - 1] y1_d90 = waveforms[np.arange(n_spikes), safe_idx_dec_90] dy_d90 = y1_d90 - y0_d90 dy_d90[dy_d90 == 0] = -1e-9 frac_d90 = (amp_90 - y0_d90) / dy_d90 x_d90 = (safe_idx_dec_90 - 1) + frac_d90 # Linear interpolation for decay 10% y0_d10 = waveforms[np.arange(n_spikes), safe_idx_dec_10 - 1] y1_d10 = waveforms[np.arange(n_spikes), safe_idx_dec_10] dy_d10 = y1_d10 - y0_d10 dy_d10[dy_d10 == 0] = -1e-9 frac_d10 = (amp_10 - y0_d10) / dy_d10 x_d10 = (safe_idx_dec_10 - 1) + frac_d10 decay_times[valid_decay] = (x_d10[valid_decay] - x_d90[valid_decay]) * dt * 1000.0 # --- AHP --- ahp_max_samples_per_spike = np.full(n_spikes, ahp_max_samples) if n_spikes > 1: dist_to_next = spike_indices[1:] - spike_indices[:-1] ahp_max_samples_per_spike[:-1] = np.minimum(ahp_max_samples, dist_to_next) ahp_range = np.arange(0, ahp_max_samples) ahp_indices = spike_indices[:, None] + ahp_range np.clip(ahp_indices, 0, n_data - 1, out=ahp_indices) ahp_waveforms = data[ahp_indices] col_idxs_ahp = np.tile(np.arange(ahp_max_samples), (n_spikes, 1)) valid_ahp_mask = col_idxs_ahp < ahp_max_samples_per_spike[:, None] window_length = int(0.005 / dt) if window_length % 2 == 0: window_length += 1 window_length = max(5, window_length) # Cap to trace width; if cap makes it even, step down to next odd so the # Savitzky-Golay constraint (window > polyorder=3) is preserved. n_cols = ahp_waveforms.shape[1] max_win = n_cols if n_cols % 2 == 1 else max(1, n_cols - 1) window_length = min(window_length, max_win) if window_length % 2 == 0: window_length = max(1, window_length - 1) if ahp_waveforms.shape[1] >= window_length and window_length >= 5: smoothed_ahp = savgol_filter(ahp_waveforms, window_length, 3, axis=1) else: smoothed_ahp = ahp_waveforms temp_ahp = smoothed_ahp.copy() temp_ahp[~valid_ahp_mask] = np.inf # DYNAMIC AHP BOUNDING: # Instead of unbounded argmin, use the first zero-crossing of the derivative # (negative to positive) indicating the end of repolarization. ahp_dvdt = dvdt[ahp_indices] ahp_dvdt[~valid_ahp_mask] = -1.0 # Ignore invalid regions crossing_mask = (ahp_dvdt[:, :-1] < 0) & (ahp_dvdt[:, 1:] >= 0) has_crossing = np.any(crossing_mask, axis=1) if crossing_mask.shape[1] > 0 else np.zeros(n_spikes, dtype=bool) first_crossing_idx = ( np.argmax(crossing_mask, axis=1) + 1 if crossing_mask.shape[1] > 0 else np.zeros(n_spikes, dtype=int) ) # Fallback to argmin if no clean zero-crossing is found ahp_min_rel_indices = np.where(has_crossing, first_crossing_idx, np.argmin(temp_ahp, axis=1)) mean_window = int(0.001 / dt) ahp_min_vals = np.zeros(n_spikes) for i in range(n_spikes): idx = ahp_min_rel_indices[i] start = max(0, idx - mean_window) end = min(ahp_max_samples_per_spike[i], idx + mean_window + 1) ahp_min_vals[i] = np.mean(ahp_waveforms[i, start:end]) rec_targets = ap_thresholds - 0.1 * amplitudes rec_target_bcast = rec_targets[:, None] is_after_min = col_idxs_ahp > ahp_min_rel_indices[:, None] is_recovered = ahp_waveforms >= rec_target_bcast valid_recovery = is_after_min & is_recovered & valid_ahp_mask has_recovery = np.any(valid_recovery, axis=1) rec_rel_indices = np.where(has_recovery, np.argmax(valid_recovery, axis=1), ahp_max_samples) thresh_bcast = ap_thresholds[:, None] is_below_thresh_ahp = ahp_waveforms < thresh_bcast has_ap_end = np.any(is_below_thresh_ahp, axis=1) ap_end_rel_indices = np.where(has_ap_end, np.argmax(is_below_thresh_ahp, axis=1), 0) ahp_durations = np.full(n_spikes, np.nan) valid_ahp_dur = has_recovery & has_ap_end & (rec_rel_indices > ap_end_rel_indices) ahp_durations[valid_ahp_dur] = (rec_rel_indices[valid_ahp_dur] - ap_end_rel_indices[valid_ahp_dur]) * dt * 1000.0 # --- ADP --- adp_amplitudes = np.full(n_spikes, np.nan) if ahp_max_samples > 2: val_mid = ahp_waveforms[:, 1:-1] val_left = ahp_waveforms[:, :-2] val_right = ahp_waveforms[:, 2:] # Find all local minima (troughs) is_local_min_inner = (val_mid < val_left) & (val_mid < val_right) is_local_min = np.pad(is_local_min_inner, ((0, 0), (1, 1)), mode="constant", constant_values=False) col_idxs2 = np.tile(np.arange(ahp_max_samples), (n_spikes, 1)) # The fast trough is the FIRST local minimum valid_min_mask = is_local_min & (col_idxs2 < ahp_max_samples_per_spike[:, None]) has_trough = np.any(valid_min_mask, axis=1) first_trough_idx = np.argmax(valid_min_mask, axis=1) # The ADP peak is the FIRST local maximum after the fast trough is_local_max_inner = (val_mid > val_left) & (val_mid > val_right) is_local_max = np.pad(is_local_max_inner, ((0, 0), (1, 1)), mode="constant", constant_values=False) valid_max_mask = ( is_local_max & (col_idxs2 > first_trough_idx[:, None]) & (col_idxs2 < ahp_max_samples_per_spike[:, None]) ) # Find highest valid local maximum between fast trough and next spike / end of window temp_vals = ahp_waveforms.copy() temp_vals[~valid_max_mask] = -np.inf adp_peaks = np.max(temp_vals, axis=1) has_adp = np.any(valid_max_mask, axis=1) & ~np.isinf(adp_peaks) fast_trough_vals = ahp_waveforms[np.arange(n_spikes), first_trough_idx] calced_adps = adp_peaks - fast_trough_vals adp_amplitudes = np.where(has_trough & has_adp, calced_adps, np.nan) # --- fAHP and mAHP (separate physiological windows) --- # fAHP: fast AHP (default 1-5 ms post-peak): Na+ channel-mediated repolarisation overshoot # mAHP: medium AHP (default 10-50 ms post-peak): K+ channel-mediated hyperpolarisation fahp_start = max(1, int(fahp_window_ms[0] / 1000.0 / dt)) fahp_end = max(fahp_start + 1, int(fahp_window_ms[1] / 1000.0 / dt)) mahp_start = max(1, int(mahp_window_ms[0] / 1000.0 / dt)) mahp_end = max(mahp_start + 1, int(mahp_window_ms[1] / 1000.0 / dt)) def _window_min(start_s: int, end_s: int) -> np.ndarray: """Return per-spike min voltage in [peak+start_s, peak+end_s).""" w_len = end_s - start_s if w_len <= 0: return np.full(n_spikes, np.nan) w_range = np.arange(start_s, end_s) w_indices = spike_indices[:, None] + w_range np.clip(w_indices, 0, n_data - 1, out=w_indices) return np.min(data[w_indices], axis=1) fahp_min_vals = _window_min(fahp_start, fahp_end) mahp_min_vals = _window_min(mahp_start, mahp_end) fahp_depths = ap_thresholds - fahp_min_vals mahp_depths = ap_thresholds - mahp_min_vals # --- max/min dV/dt --- raw_dvdt = np.gradient(waveforms, axis=1) / dt / 1000.0 # Apply a dynamic sampling-rate dependent rolling window (standard ~0.1 ms) # to smooth the derivative, matching standard smoothing behavior # and preventing single-sample noise spikes from inflating the rate. window_ms = 0.1 window_size = max(3, int(window_ms / (dt * 1000.0))) if window_size % 2 == 0: window_size += 1 kernel = np.ones(window_size) / window_size from scipy.ndimage import convolve1d full_dvdt = convolve1d(raw_dvdt, kernel, axis=1, mode="nearest") # Calculate column indices of thresholds rel_thresh_indices = thresh_indices - (spike_indices - lookback_samples) # max dV/dt strictly between threshold and peak valid_rise_mask = is_pre_peak[None, :] & (col_indices[None, :] >= rel_thresh_indices[:, None]) pre_peak_dvdt = np.where(valid_rise_mask, full_dvdt, -np.inf) # min dV/dt after peak post_peak_dvdt = np.where(is_post_peak[None, :], full_dvdt, np.inf) max_dvdts = np.max(pre_peak_dvdt, axis=1) min_dvdts = np.min(post_peak_dvdt, axis=1) # --- New Active Features --- ap_delays = time[thresh_indices] ahp_times = np.zeros(n_spikes) for i in range(n_spikes): ahp_times[i] = time[ahp_indices[i, ahp_min_rel_indices[i]]] trough_vs = ahp_min_vals upstroke_downstroke_ratios = np.full(n_spikes, np.nan) valid_ratio = min_dvdts != 0 upstroke_downstroke_ratios[valid_ratio] = max_dvdts[valid_ratio] / np.abs(min_dvdts[valid_ratio]) phase_plane_areas = np.full(n_spikes, np.nan) ap_widths_arbitrary = [None] * n_spikes for i in range(n_spikes): start_idx = thresh_indices[i] end_idx = spike_indices[i] + ap_end_rel_indices[i] # Calculate Phase Plane Area (Shoelace formula) if start_idx < end_idx: # Find the local slice within the full_dvdt window rel_start = (start_idx - spike_indices[i]) + lookback_samples rel_end = (end_idx - spike_indices[i]) + lookback_samples # Ensure boundaries are within full_dvdt shape rel_start = max(0, min(rel_start, full_window_len - 1)) rel_end = max(0, min(rel_end, full_window_len)) if rel_end > rel_start + 2: x = data[start_idx : start_idx + (rel_end - rel_start)] y = full_dvdt[i, rel_start:rel_end] if len(x) == len(y): area = 0.5 * np.abs(np.dot(x[:-1], y[1:]) - np.dot(x[1:], y[:-1])) phase_plane_areas[i] = area features_list = [] for i in range(n_spikes): features_list.append( SingleSpikeResult( value=None, unit="mV", ap_threshold=float(ap_thresholds[i]), amplitude=float(amplitudes[i]), half_width=float(half_widths[i]), rise_time_10_90=float(rise_times[i]), decay_time_90_10=float(decay_times[i]), fahp_depth=float(fahp_depths[i]), mahp_depth=float(mahp_depths[i]), ahp_duration_half=float(ahp_durations[i]), adp_amplitude=float(adp_amplitudes[i]), max_dvdt=float(max_dvdts[i]), min_dvdt=float(min_dvdts[i]), absolute_peak_mv=float(peak_vals[i]), overshoot_mv=float(max(0.0, peak_vals[i])), ap_delay=float(ap_delays[i]), ap_width_arbitrary=ap_widths_arbitrary[i], ahp_time=float(ahp_times[i]), upstroke_downstroke_ratio=float(upstroke_downstroke_ratios[i]), phase_plane_area=float(phase_plane_areas[i]), trough_v=float(trough_vs[i]), ) ) return features_list
[docs] def calculate_isi(spike_times: np.ndarray) -> np.ndarray: """Return inter-spike intervals from spike_times array.""" if len(spike_times) < 2: return np.array([]) return np.diff(spike_times)
[docs] def analyze_multi_sweep_spikes( data_trials: List[np.ndarray], time_vector: np.ndarray, threshold: float, refractory_samples: int, dvdt_threshold: float = 20.0, ) -> List[SpikeTrainResult]: """Detect spikes across multiple sweeps.""" results = [] for i, trial_data in enumerate(data_trials): try: result = detect_spikes_threshold( trial_data, time_vector, threshold, refractory_samples, dvdt_threshold=dvdt_threshold ) result.metadata["sweep_index"] = i results.append(result) except (ValueError, TypeError, KeyError, IndexError) as e: log.error(f"Error analyzing sweep {i}: {e}") error_result = SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message=f"Sweep {i}: {str(e)}" ) error_result.metadata["sweep_index"] = i results.append(error_result) return results
# --------------------------------------------------------------------------- # Phase Plane (dV/dt vs V) # ---------------------------------------------------------------------------
[docs] def calculate_dvdt(voltage: np.ndarray, sampling_rate: float, sigma_ms: float = 0.1) -> np.ndarray: """ Calculate dV/dt (V/s) with optional Savitzky-Golay smoothing. Computes the raw derivative first, then applies a Savitzky-Golay filter (polynomial order 3) directly to the derivative array. This preserves the true max dV/dt better than pre-smoothing the voltage with a Gaussian, which attenuates the sharp upstroke of action potentials. Args: voltage: 1D voltage array (mV). sampling_rate: Sampling rate (Hz). sigma_ms: Smoothing window (ms). The SG window length is derived as ``max(5, int(sigma_ms / 1000 * sampling_rate))``, rounded up to the next odd integer. Set to 0 for no smoothing. Returns: 1D array of dV/dt in V/s. """ dt = 1.0 / sampling_rate dvdt = np.gradient(voltage, dt) / 1000.0 # mV/s -> V/s if sigma_ms > 0 and len(dvdt) >= 5: # Dynamic window length derived from sigma_ms and sampling rate (must be odd >= 5) window_samples = max(5, int(sigma_ms / 1000.0 * sampling_rate)) if window_samples % 2 == 0: window_samples += 1 # Cap at signal length (savgol_filter requires window <= len) window_samples = min(window_samples, len(dvdt) if len(dvdt) % 2 == 1 else len(dvdt) - 1) if window_samples >= 5: dvdt = savgol_filter(dvdt, window_samples, 3) return dvdt
[docs] def get_phase_plane_trajectory( voltage: np.ndarray, sampling_rate: float, sigma_ms: float = 0.1 ) -> Tuple[np.ndarray, np.ndarray]: """Return (voltage, dvdt) phase-plane trajectory.""" dvdt = calculate_dvdt(voltage, sampling_rate, sigma_ms) return voltage, dvdt
[docs] def detect_threshold_kink( voltage: np.ndarray, sampling_rate: float, dvdt_threshold: float = 20.0, kink_slope: float = 10.0, search_window_ms: float = 5.0, peak_indices: Optional[np.ndarray] = None, ) -> np.ndarray: """ Detect AP threshold using the dV/dt kink method. Returns array of threshold indices. """ if peak_indices is None: res = detect_spikes_threshold( voltage, np.arange(len(voltage)) / sampling_rate, -20.0, int(0.002 * sampling_rate) ) peak_indices = res.spike_indices dvdt = calculate_dvdt(voltage, sampling_rate, sigma_ms=0.1) threshold_indices = [] search_samples = int((search_window_ms / 1000.0) * sampling_rate) for peak_idx in peak_indices: start_search = max(0, peak_idx - search_samples) dvdt_slice = dvdt[start_search:peak_idx] crossings = np.where(dvdt_slice > dvdt_threshold)[0] if crossings.size > 0: thresh_idx = start_search + crossings[0] else: thresh_idx = max(0, peak_idx - int(0.001 * sampling_rate)) threshold_indices.append(thresh_idx) return np.array(threshold_indices)
# --------------------------------------------------------------------------- # Registry Wrappers # ---------------------------------------------------------------------------
[docs] @AnalysisRegistry.register( "spike_detection", label="Spike Detection", ui_params=[ { "name": "threshold", "label": "Threshold (mV):", "type": "float", "default": -20.0, "min": -1e9, "max": 1e9, "decimals": 4, }, { "name": "refractory_period", "label": "Refractory (s):", "type": "float", "default": 0.002, "min": 0.0, "max": 1e9, "decimals": 4, }, { "name": "peak_search_window", "label": "Peak Search (s):", "type": "float", "default": 0.005, "min": 0.0, "max": 1.0, "decimals": 4, }, { "name": "dvdt_threshold", "label": "dV/dt Thresh (V/s):", "type": "float", "default": 20.0, "min": 0.0, "max": 1e6, "decimals": 1, }, { "name": "ahp_window", "label": "AHP Window (s):", "type": "float", "default": 0.05, "min": 0.0, "max": 10.0, "decimals": 3, }, { "name": "onset_lookback", "label": "Onset Lookback (s):", "type": "float", "default": 0.01, "min": 0.0, "max": 0.1, "decimals": 3, }, { "name": "ljp_correction_mv", "label": "LJP Correction (mV):", "type": "float", "default": 0.0, "min": -100.0, "max": 100.0, "decimals": 2, "tooltip": "Liquid Junction Potential in mV. V_true = V_recorded - LJP.", }, ], plots=[ {"type": "hlines", "data": ["threshold"], "color": "r", "styles": ["dash"]}, {"type": "markers", "x": "spike_times", "y": "spike_voltages", "color": "r"}, ], ) def run_spike_detection_wrapper( data: np.ndarray, time: np.ndarray, sampling_rate: float, threshold: float = -20.0, refractory_period: float = 0.002, peak_search_window: float = 0.005, dvdt_threshold: float = 20.0, ahp_window: float = 0.05, onset_lookback: float = 0.01, **kwargs, ) -> Dict[str, Any]: """Wrapper for spike detection. Returns namespaced schema.""" try: ljp_mv = float(kwargs.get("ljp_correction_mv", 0.0)) data = apply_ljp_correction(data, ljp_mv) refractory_samples = int(refractory_period * sampling_rate) peak_window_samples = int(peak_search_window * sampling_rate) params = { "threshold": threshold, "refractory_period": refractory_period, "peak_search_window": peak_search_window, "dvdt_threshold": dvdt_threshold, "ahp_window": ahp_window, "onset_lookback": onset_lookback, } result = detect_spikes_threshold( data, time, threshold, refractory_samples, peak_search_window_samples=peak_window_samples, parameters=params, dvdt_threshold=dvdt_threshold, ) if result.is_valid: features_list = calculate_spike_features( data, time, result.spike_indices, dvdt_threshold=dvdt_threshold, ahp_window_sec=ahp_window, onset_lookback=onset_lookback, ) stats: Dict[str, Any] = {} if features_list: import dataclasses # Collect attributes dynamically from the first valid result first_feat = features_list[0] feat_dict = dataclasses.asdict(first_feat) valid_keys = [ k for k in feat_dict.keys() if k not in [ "value", "unit", "is_valid", "error_message", "quality_flags", "confidence", "metadata", "parameters", ] ] for key in valid_keys: values = [ getattr(f, key) for f in features_list if getattr(f, key) is not None and not np.isnan(getattr(f, key)) ] if values: stats[f"{key}_mean"] = float(np.mean(values)) stats[f"{key}_std"] = float(np.std(values)) else: stats[f"{key}_mean"] = np.nan stats[f"{key}_std"] = np.nan v_data = ( data[result.spike_indices] if result.spike_indices is not None and len(result.spike_indices) > 0 else np.array([]) ) metrics: Dict[str, Any] = { "spike_count": len(result.spike_indices) if result.spike_indices is not None else 0, "mean_freq_hz": result.mean_frequency if result.mean_frequency is not None else 0.0, "spike_times": result.spike_times, "spike_indices": result.spike_indices, "spike_voltages": v_data, "threshold": threshold, "parameters": params, } metrics.update(stats) else: metrics = { "spike_count": 0, "mean_freq_hz": 0.0, "threshold": threshold, "spike_error": result.error_message or "Unknown error", "parameters": params, } return {"module_used": "single_spike", "metrics": metrics} except (ValueError, TypeError, KeyError, IndexError) as e: log.error(f"Error in run_spike_detection_wrapper: {e}", exc_info=True) return { "module_used": "single_spike", "metrics": {"spike_count": 0, "mean_freq_hz": 0.0, "spike_error": str(e)}, }
[docs] @AnalysisRegistry.register( "phase_plane_analysis", label="Phase Plane", plots=[ {"name": "Trace", "type": "trace"}, {"type": "popup_phase", "title": "Phase Plane"}, ], ui_params=[ { "name": "sigma_ms", "label": "Smoothing (ms):", "type": "float", "default": 0.1, "min": 0.0, "max": 1e9, "decimals": 4, }, { "name": "dvdt_threshold", "label": "dV/dt Thresh (V/s):", "type": "float", "default": 20.0, "min": 0.0, "max": 1e9, "decimals": 4, }, { "name": "spike_threshold", "label": "Spike Detect Thresh (mV):", "type": "float", "default": -20.0, "min": -1000.0, "max": 1000.0, "decimals": 2, }, {"name": "kink_slope", "label": "Kink Slope:", "type": "float", "default": 10.0, "hidden": True}, { "name": "search_window_ms", "label": "Search Window (ms):", "type": "float", "default": 5.0, "min": 0.1, "max": 100.0, "decimals": 2, }, { "name": "ljp_correction_mv", "label": "LJP Correction (mV):", "type": "float", "default": 0.0, "min": -100.0, "max": 100.0, "decimals": 2, "tooltip": "Liquid Junction Potential in mV. V_true = V_recorded - LJP.", }, ], ) def phase_plane_analysis_wrapper( voltage: np.ndarray, time: np.ndarray, sampling_rate: float, sigma_ms: float = 0.1, dvdt_threshold: float = 20.0, **kwargs, ) -> Dict[str, Any]: """Wrapper for Phase Plane analysis. Returns namespaced schema.""" spike_threshold = kwargs.get("spike_threshold", -20.0) search_window_ms = kwargs.get("search_window_ms", 5.0) kink_slope = kwargs.get("kink_slope", 10.0) ljp_mv = float(kwargs.get("ljp_correction_mv", 0.0)) voltage = apply_ljp_correction(voltage, ljp_mv) v, dvdt = get_phase_plane_trajectory(voltage, sampling_rate, sigma_ms) spike_res = detect_spikes_threshold(voltage, time, spike_threshold, int(0.002 * sampling_rate)) thresh_indices = detect_threshold_kink( voltage, sampling_rate, dvdt_threshold=dvdt_threshold, kink_slope=kink_slope, search_window_ms=search_window_ms, peak_indices=spike_res.spike_indices, ) threshold_vals = voltage[thresh_indices] if thresh_indices.size > 0 else [] metrics = { "voltage": v, "dvdt": dvdt, "threshold_indices": thresh_indices, "threshold_vals": threshold_vals, "threshold_v": float(np.mean(threshold_vals)) if len(threshold_vals) > 0 else np.nan, "threshold_dvdt": float(dvdt_threshold), "max_dvdt": float(np.max(dvdt)) if len(dvdt) > 0 else 0.0, "threshold_mean": float(np.mean(threshold_vals)) if len(threshold_vals) > 0 else np.nan, } return {"module_used": "single_spike", "metrics": metrics}
# Keep the original function name as an alias so existing code and tests still work phase_plane_analysis = phase_plane_analysis_wrapper # --------------------------------------------------------------------------- # Module-level tab aggregator # ---------------------------------------------------------------------------
[docs] @AnalysisRegistry.register( "single_spike", label="Spike Analysis", method_selector={ "Spike Detection": "spike_detection", "Phase Plane": "phase_plane_analysis", }, ui_params=[], plots=[], ) def single_spike_module(**kwargs): """Module-level aggregator tab for single-spike analyses.""" return {}