# src/Synaptipy/core/analysis/evoked_responses.py
# -*- coding: utf-8 -*-
"""
Core Protocol Module 5: Evoked Responses.
Consolidates optogenetic stimulus synchronization (TTL-gated latency,
probability, jitter analysis) from optogenetics.py.
All registry wrapper functions return::

    {
        "module_used": "evoked_responses",
        "metrics": { ... flat result keys ... }
    }
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from scipy.optimize import curve_fit
from Synaptipy.core.analysis.registry import AnalysisRegistry
from Synaptipy.core.analysis.single_spike import detect_spikes_threshold
from Synaptipy.core.analysis.synaptic_events import detect_events_template, detect_events_threshold
from Synaptipy.core.results import AnalysisResult
from Synaptipy.core.signal_processor import find_artifact_windows
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Result dataclass
# ---------------------------------------------------------------------------
[docs]
@dataclass
class OptoSyncResult(AnalysisResult):
    """Outcome of a stimulus-to-event synchronisation analysis.

    Adds latency, response-probability and jitter statistics, per-stimulus
    bookkeeping and the analysis parameters on top of ``AnalysisResult``.
    """

    # Mean latency (ms) from stimulus onset to the first event in the window.
    optical_latency_ms: Optional[float] = None
    # Fraction of stimuli followed by at least one event in the window.
    response_probability: Optional[float] = None
    # Standard deviation (ms) of the per-stimulus first-event latencies.
    spike_jitter_ms: Optional[float] = None
    # Number of detected stimuli, and how many did / did not evoke an event.
    stimulus_count: int = 0
    success_count: int = 0
    failure_count: int = 0
    # Stimulus onset / offset times as returned by the TTL epoch extraction.
    stimulus_onsets: Optional[np.ndarray] = None
    stimulus_offsets: Optional[np.ndarray] = None
    # One list of event times per stimulus (empty list for a failure).
    responding_spikes: List[List[float]] = field(default_factory=list)
    # Analysis parameters used to produce this result.
    parameters: Dict[str, Any] = field(default_factory=dict)

    def __repr__(self):
        """Return a one-line summary, or the error message for invalid results."""
        if not self.is_valid:
            return f"OptoSyncResult(Error: {self.error_message})"

        def _two_decimals(number: Optional[float]) -> str:
            # Missing statistics are rendered as "N/A" instead of raising.
            return "N/A" if number is None else f"{number:.2f}"

        latency_txt = _two_decimals(self.optical_latency_ms)
        probability_txt = _two_decimals(self.response_probability)
        jitter_txt = _two_decimals(self.spike_jitter_ms)
        return (
            f"OptoSyncResult(Latency={latency_txt} ms, Prob={probability_txt}, "
            f"Success={self.success_count}/{self.stimulus_count}, "
            f"Jitter={jitter_txt} ms)"
        )
# ---------------------------------------------------------------------------
# TTL Extraction
# ---------------------------------------------------------------------------
def _find_spikes_in_window(spikes: np.ndarray, t_start: float, t_end: float) -> np.ndarray:
    """Select the entries of *spikes* inside the closed interval [t_start, t_end].

    An empty input short-circuits to a fresh empty array.
    """
    if spikes.size != 0:
        inside = np.logical_and(spikes >= t_start, spikes <= t_end)
        return spikes[inside]
    return np.array([])
# ---------------------------------------------------------------------------
# Core Analysis
# ---------------------------------------------------------------------------
[docs]
def calculate_optogenetic_sync(
    ttl_data: np.ndarray,
    action_potential_times: np.ndarray,
    time: np.ndarray,
    ttl_threshold: float = 2.5,
    response_window_ms: float = 20.0,
) -> OptoSyncResult:
    """
    Correlate TTL stimuli with action potential times.

    Stimulus onsets/offsets are obtained from ``extract_ttl_epochs``. For each
    onset, the events inside ``[onset, onset + response_window_ms]`` are
    collected; a stimulus with at least one such event counts as a success and
    contributes the latency of its *earliest* event. Latency (mean) and jitter
    (standard deviation) are computed over successful stimuli only.

    Args:
        ttl_data: Digital signal data trace.
        action_potential_times: Pre-calculated spike/event times (seconds).
            Any array-like is accepted; it does not need to be sorted.
        time: Timestamps of the trace.
        ttl_threshold: Voltage threshold for TTL edge detection.
        response_window_ms: Search window for APs after stimulus onset (ms).

    Returns:
        OptoSyncResult. It is invalid (``is_valid=False``) when the TTL trace
        is empty or no stimulus is detected. When stimuli exist but none
        evoked an event, latency and jitter are NaN and the response
        probability is 0.0.
    """
    if np.size(ttl_data) == 0:
        return OptoSyncResult(value=None, unit="", is_valid=False, error_message="Empty TTL Data")
    onsets, offsets = extract_ttl_epochs(ttl_data, time, ttl_threshold)
    stimulus_count = len(onsets)
    if stimulus_count == 0:
        return OptoSyncResult(
            value=None,
            unit="",
            is_valid=False,
            error_message="No TTL stimuli detected above threshold",
        )
    # Coerce once so that plain lists (e.g. event times forwarded from caller
    # kwargs) work with the vectorised window helper.
    event_times = np.asarray(action_potential_times, dtype=float).ravel()
    window_s = response_window_ms / 1000.0
    latencies = []
    responding_spikes = []
    response_count = 0
    for onset in onsets:
        valid_spikes = _find_spikes_in_window(event_times, onset, onset + window_s)
        responding_spikes.append(valid_spikes.tolist())
        if valid_spikes.size > 0:
            response_count += 1
            # Use the minimum rather than element 0 so the latency is that of
            # the earliest event even when the input times are unsorted.
            latencies.append((float(np.min(valid_spikes)) - onset) * 1000.0)
    failure_count = stimulus_count - response_count
    # Latency and jitter are computed only over *successful* trials to prevent
    # NaN propagation from failure trials.
    if response_count > 0:
        optical_latency_ms = float(np.mean(latencies))
        spike_jitter_ms = float(np.std(latencies)) if len(latencies) > 1 else 0.0
        response_probability = float(response_count / stimulus_count)
    else:
        optical_latency_ms = np.nan
        spike_jitter_ms = np.nan
        response_probability = 0.0
    return OptoSyncResult(
        value=optical_latency_ms,
        unit="ms",
        is_valid=True,
        optical_latency_ms=optical_latency_ms,
        response_probability=response_probability,
        spike_jitter_ms=spike_jitter_ms,
        stimulus_count=stimulus_count,
        success_count=response_count,
        failure_count=failure_count,
        stimulus_onsets=onsets,
        stimulus_offsets=offsets,
        responding_spikes=responding_spikes,
        parameters={"ttl_threshold": ttl_threshold, "response_window_ms": response_window_ms},
    )
# ---------------------------------------------------------------------------
# Shared scatter-marker helper
# ---------------------------------------------------------------------------
def _peak_pos_s(
    data: np.ndarray,
    time: np.ndarray,
    onset_s: float,
    polarity: str,
    blank_s: float,
    win_s: float,
) -> Tuple[float, float]:
    """Locate the response extremum used for scatter-plot markers.

    The search window starts ``blank_s`` after the stimulus onset (artifact
    blanking) and ends one sample past ``onset_s + win_s``, clipped to the
    trace length. Within it the minimum is taken for ``"negative"`` polarity
    and the maximum for any other polarity value.

    Args:
        data: 1-D signal trace.
        time: 1-D time vector (same length as *data*).
        onset_s: Stimulus onset time in seconds.
        polarity: ``"negative"`` or ``"positive"``.
        blank_s: Artifact-blanking duration in seconds.
        win_s: Response-search window duration in seconds.

    Returns:
        ``(peak_time_s, peak_raw_value)``. If the window is empty, the onset
        time is returned together with the signal value at the onset sample
        (clipped to the last sample).
    """
    n_samples = len(data)
    first = int(np.searchsorted(time, onset_s + blank_s))
    stop = min(int(np.searchsorted(time, onset_s + win_s)) + 1, n_samples)
    if first < stop:
        window = data[first:stop]
        locate = np.argmin if polarity == "negative" else np.argmax
        rel = int(locate(window))
        return float(time[first + rel]), float(window[rel])
    # Empty window: fall back to the sample at (or just after) the onset.
    onset_idx = min(int(np.searchsorted(time, onset_s)), n_samples - 1)
    return onset_s, float(data[onset_idx])
# ---------------------------------------------------------------------------
# Paired-Pulse Ratio with Residual Subtraction
# ---------------------------------------------------------------------------
[docs]
def calculate_paired_pulse_ratio(  # noqa: C901
    data: np.ndarray,
    time: np.ndarray,
    stim1_onset_s: float,
    stim2_onset_s: float,
    response_window_ms: float = 20.0,
    baseline_window_ms: float = 5.0,
    fit_decay_from_ms: float = 5.0,
    fit_decay_window_ms: float = 30.0,
    polarity: str = "negative",
    artifact_blanking_ms: float = 1.0,
) -> Dict[str, Any]:
    """Calculate the Paired-Pulse Ratio, correcting R2 for the R1 decay tail.

    When the first response has not decayed back to rest by the time of the
    second stimulus, measuring the second response from its own local
    baseline gives a biased amplitude. This function therefore references the
    second peak to the pre-stimulus baseline of the *first* response.

    Algorithm:
        1. Measure the amplitude of response 1 (R1) relative to its local
           pre-stimulus baseline (``bl1``).
        2. Fit a bi-exponential decay (mono-exponential as fallback) to the
           tail of R1, starting ``fit_decay_from_ms`` after stim1 and lasting
           ``fit_decay_window_ms``, truncated at stim2.
        3. Extrapolate the fit to stim2 and report
           ``residual_at_stim2 = fit(stim2) - bl1``. This value is diagnostic
           only; it is not subtracted from R2.
        4. Measure the raw amplitude of response 2 relative to its own local
           pre-stimulus baseline (``bl2``).
        5. Compute the corrected R2 amplitude as the R2 peak relative to
           ``bl1``.
        6. Return ``paired_pulse_ratio = R2_corrected / R1``.

    If the fit window holds fewer than four samples, no fit is attempted and
    the ratio is ``R2_raw / R1``.

    Args:
        data: 1-D voltage/current array (mV or pA).
        time: 1-D time array (s), same shape as *data*.
        stim1_onset_s: Time of first stimulus onset (s).
        stim2_onset_s: Time of second stimulus onset (s).
        response_window_ms: Duration after each stimulus to search for peak (ms).
        baseline_window_ms: Pre-stimulus baseline window (ms) to compute local
            baseline for each response.
        fit_decay_from_ms: Offset from stim1_onset to start fitting decay (ms).
            Should be after the initial transient.
        fit_decay_window_ms: Window duration for decay fit (ms).
        polarity: ``"negative"`` (inward/downward events, e.g. EPSCs) or
            ``"positive"``.
        artifact_blanking_ms: Duration (ms) after each stimulus onset to ignore
            when searching for the peak response (default 1.0). Prevents the
            stimulus artefact from being identified as the response peak.

    Returns:
        Dict with keys:
            - ``r1_amplitude`` – amplitude of first response (baseline-subtracted)
            - ``r2_amplitude_raw`` – amplitude of second response relative to bl2
            - ``r2_amplitude_corrected`` – amplitude of second response
              relative to bl1
            - ``residual_at_stim2`` – extrapolated fit value at stim2 minus bl1
              (None if the fit failed, 0.0 if no fit was attempted)
            - ``paired_pulse_ratio`` – R2_corrected / R1
            - ``decay_tau_ms`` – time constant of first event decay (ms);
              amplitude-weighted for the bi-exponential fit
            - ``ppr_error`` – None on success; error string otherwise
            - ``_ppr_fit_times`` / ``_ppr_fit_values`` – fitted curve for plot
              overlays; present only when the decay fit succeeded
    """
    out: Dict[str, Any] = {
        "r1_amplitude": None,
        "r2_amplitude_raw": None,
        "r2_amplitude_corrected": None,
        "residual_at_stim2": None,
        "paired_pulse_ratio": None,
        "decay_tau_ms": None,
        "ppr_error": None,
    }
    if data.size < 2 or time.shape != data.shape:
        out["ppr_error"] = "Invalid data or time array"
        return out

    n_samples = len(data)

    def _nearest_idx(t: float) -> int:
        # Insertion index of t in the time vector; may equal n_samples when t
        # lies beyond the last timestamp.
        return int(np.searchsorted(time, t))

    def _local_baseline(stim_onset_s: float) -> float:
        """Mean of the signal in the baseline window just before the onset."""
        bl_start_s = max(stim_onset_s - baseline_window_ms / 1000.0, float(time[0]))
        i0 = _nearest_idx(bl_start_s)
        i1 = max(i0 + 1, _nearest_idx(stim_onset_s))
        segment = data[i0:i1]
        if segment.size > 0:
            return float(np.mean(segment))
        # The window starts past the end of the trace: use the last sample
        # instead of indexing out of bounds.
        return float(data[min(_nearest_idx(stim_onset_s), n_samples - 1)])

    def _response_peak(stim_onset_s: float, baseline: float) -> Tuple[float, float]:
        """Return (peak_amplitude, raw_peak_value) relative to baseline.

        Data within ``artifact_blanking_ms`` of the stimulus onset are excluded
        so that the stimulus artefact is never mistaken for the biological peak.
        """
        blank_s = artifact_blanking_ms / 1000.0
        win_start = _nearest_idx(stim_onset_s + blank_s)
        win_end = min(_nearest_idx(stim_onset_s + response_window_ms / 1000.0) + 1, n_samples)
        if win_end <= win_start:
            return 0.0, baseline
        segment = data[win_start:win_end]
        if polarity == "negative":
            peak_raw = float(np.min(segment))
            return baseline - peak_raw, peak_raw
        peak_raw = float(np.max(segment))
        return peak_raw - baseline, peak_raw

    # --- R1 ---
    bl1 = _local_baseline(stim1_onset_s)
    r1_amp, _ = _response_peak(stim1_onset_s, bl1)
    out["r1_amplitude"] = r1_amp
    if r1_amp <= 0:
        out["ppr_error"] = "R1 amplitude <= 0; cannot compute PPR"
        return out

    # --- Exponential decay fit on R1 tail ---
    def _mono_exp(t: np.ndarray, a: float, tau: float, c: float) -> np.ndarray:
        return a * np.exp(-t / tau) + c

    def _bi_exp(t: np.ndarray, a_f: float, tau_f: float, a_s: float, tau_s: float, c: float) -> np.ndarray:
        return a_f * np.exp(-t / tau_f) + a_s * np.exp(-t / tau_s) + c

    fit_start_s = stim1_onset_s + fit_decay_from_ms / 1000.0
    fit_end_s = stim1_onset_s + (fit_decay_from_ms + fit_decay_window_ms) / 1000.0
    # Never fit into the second response.
    fit_end_s = min(fit_end_s, stim2_onset_s)
    i_fit0 = _nearest_idx(fit_start_s)
    i_fit1 = _nearest_idx(fit_end_s)
    if i_fit1 - i_fit0 < 4:
        # Fallback: no residual correction
        bl2 = _local_baseline(stim2_onset_s)
        r2_amp_raw, _ = _response_peak(stim2_onset_s, bl2)
        out["r2_amplitude_raw"] = r2_amp_raw
        out["r2_amplitude_corrected"] = r2_amp_raw
        out["residual_at_stim2"] = 0.0
        out["decay_tau_ms"] = None
        if r1_amp > 0:
            out["paired_pulse_ratio"] = r2_amp_raw / r1_amp
        out["ppr_error"] = "Decay fit window too short; no residual correction applied"
        return out

    t_fit = (time[i_fit0:i_fit1] - time[i_fit0]) * 1000.0  # ms, relative to fit start
    y_fit = data[i_fit0:i_fit1]
    # Amplitude at fit start relative to long-run asymptote (approx bl1)
    a0 = float(y_fit[0] - bl1) if polarity == "positive" else float(bl1 - y_fit[0])
    a0 = max(a0, 1e-6)
    tau0 = max(1.0, float(t_fit[-1]) / 3.0)
    tau_ms = None
    fit_func = None
    fit_params = None
    # Best-effort: any unexpected failure of the fit is reported through
    # ``ppr_error`` and the amplitude measurements below still run.
    try:
        t_at_stim2_ms = (stim2_onset_s - time[i_fit0]) * 1000.0
        # Amplitude bound tied to the measured response size; prevents
        # parameter explosion during optimisation.
        amp_bound = max(a0 * 3.0, abs(r1_amp) * 2.0, 1e-6)
        # ── Attempt bi-exponential fit (requires >= 8 samples for 5 params) ──
        if len(t_fit) >= 8:
            try:
                if polarity == "negative":
                    bi_p0 = [-a0 * 0.7, tau0 * 0.3, -a0 * 0.3, tau0, bl1]
                    bi_lower = [-amp_bound, 0.1, -amp_bound, 0.1, bl1 - abs(r1_amp) * 2]
                    bi_upper = [0.0, tau0 * 100, 0.0, tau0 * 100, bl1 + abs(r1_amp)]
                else:
                    bi_p0 = [a0 * 0.7, tau0 * 0.3, a0 * 0.3, tau0, bl1]
                    bi_lower = [0.0, 0.1, 0.0, 0.1, bl1 - abs(r1_amp)]
                    bi_upper = [amp_bound, tau0 * 100, amp_bound, tau0 * 100, bl1 + abs(r1_amp) * 2]
                popt_bi, pcov_bi = curve_fit(_bi_exp, t_fit, y_fit, p0=bi_p0, bounds=(bi_lower, bi_upper), maxfev=4000)
                # Fall back if covariance matrix cannot be estimated (degenerate fit).
                if np.any(~np.isfinite(pcov_bi)):
                    raise ValueError("Infinite covariance: bi-exp degenerate")
                a_f_fit, tau_f_fit, a_s_fit, tau_s_fit, _ = popt_bi
                total_amp = abs(a_f_fit) + abs(a_s_fit)
                if total_amp < 1e-12:
                    raise ValueError("Bi-exp amplitudes effectively zero")
                # Amplitude-weighted dominant time constant.
                tau_ms = float((abs(a_f_fit) * tau_f_fit + abs(a_s_fit) * tau_s_fit) / total_amp)
                fit_func = _bi_exp
                fit_params = popt_bi
            except (RuntimeError, ValueError) as _bi_exc:
                log.debug("PPR bi-exp failed (%s); falling back to mono-exp.", _bi_exc)
        # ── Mono-exponential fallback ──
        if fit_params is None:
            try:
                if polarity == "negative":
                    popt_mono, _ = curve_fit(
                        _mono_exp,
                        t_fit,
                        y_fit,
                        p0=[-a0, tau0, bl1],
                        bounds=([-amp_bound, 0.1, bl1 - abs(r1_amp) * 2], [0.0, tau0 * 50, bl1 + abs(r1_amp)]),
                        maxfev=3000,
                    )
                else:
                    popt_mono, _ = curve_fit(
                        _mono_exp,
                        t_fit,
                        y_fit,
                        p0=[a0, tau0, bl1],
                        bounds=([0.0, 0.1, bl1 - abs(r1_amp)], [amp_bound, tau0 * 50, bl1 + abs(r1_amp) * 2]),
                        maxfev=3000,
                    )
                tau_ms = float(popt_mono[1])
                fit_func = _mono_exp
                fit_params = popt_mono
            except (RuntimeError, ValueError) as _mono_exc:
                log.debug("PPR mono-exp fallback failed (%s); tau_ms stays NaN.", _mono_exc)
        if fit_func is None:
            # Both models failed: say so explicitly instead of tripping over a
            # call on ``None``. Tau and residual keep their ``None`` defaults.
            failure = "bi- and mono-exponential fits did not converge"
            log.warning("PPR decay fit failed: %s", failure)
            out["ppr_error"] = f"Decay fit failed: {failure}"
        else:
            out["decay_tau_ms"] = tau_ms
            out["residual_at_stim2"] = float(fit_func(t_at_stim2_ms, *fit_params)) - bl1
            # Store fitted curve for visual overlay (private keys hidden from results table).
            out["_ppr_fit_times"] = time[i_fit0:i_fit1].tolist()
            out["_ppr_fit_values"] = np.asarray(fit_func(t_fit, *fit_params), dtype=float).tolist()
    except Exception as exc:
        log.warning("PPR decay fit failed: %s", exc)
        out["ppr_error"] = f"Decay fit failed: {exc}"

    # --- R2 ---
    bl2 = _local_baseline(stim2_onset_s)
    r2_amp_raw, r2_peak_raw = _response_peak(stim2_onset_s, bl2)
    out["r2_amplitude_raw"] = r2_amp_raw
    # Compute the corrected R2 amplitude measured from bl1 (the true resting
    # baseline before any stimulation), not from bl2 (the local baseline just
    # before stim2 which may be contaminated by the R1 decay tail).
    #
    # Scientific rationale (Zucker & Regehr 2002, Regehr 2012):
    # The "raw" amplitude r2_amp_raw is measured from bl2. When the R1 decay
    # has not fully returned to baseline by stim2, bl2 < bl1 (for inward/
    # negative events) or bl2 > bl1 (for outward/positive events). Using bl2
    # as reference therefore underestimates the true R2 amplitude. The
    # corrected amplitude is obtained by using bl1 as the reference, which
    # directly captures the contamination without relying on a potentially
    # poor extrapolation of the decay fit.
    #
    # Derivation:
    #   r2_peak_raw = actual peak value (returned by _response_peak)
    #   Negative: r2_corrected = bl1 - r2_peak_raw
    #   Positive: r2_corrected = r2_peak_raw - bl1
    #
    # When residual is negligible (bl2 ≈ bl1): r2_corrected ≈ r2_amp_raw.
    # When residual is significant: r2_corrected uses bl1 as reference.
    if polarity == "negative":
        r2_corrected = bl1 - r2_peak_raw
    else:
        r2_corrected = r2_peak_raw - bl1
    out["r2_amplitude_corrected"] = float(r2_corrected)
    if r1_amp > 0:
        out["paired_pulse_ratio"] = float(r2_corrected) / r1_amp
    return out
# ---------------------------------------------------------------------------
# Registry Wrapper
# ---------------------------------------------------------------------------
[docs]
@AnalysisRegistry.register(
    name="optogenetic_sync",
    label="Evoked Sync",
    requires_secondary_channel={
        "param_name": "ttl_data",
        "label": "TTL / Stimulus Channel:",
        "tooltip": "Select the digital/TTL or stimulus channel (optical or electrical).",
    },
    ui_params=[
        {
            "name": "ttl_threshold",
            "type": "float",
            "label": "TTL Threshold (V)",
            "default": 2.5,
            "min": -1e9,
            "max": 1e9,
            "decimals": 4,
            "tooltip": "Voltage threshold to define stimulus ON state.",
        },
        {
            "name": "response_window_ms",
            "type": "float",
            "label": "Response Window (ms)",
            "default": 20.0,
            "min": 0.0,
            "max": 1e9,
            "decimals": 2,
            "tooltip": "Time window after stimulus onset to search for events.",
        },
        {
            "name": "event_detection_type",
            "type": "choice",
            "label": "Event Type:",
            "choices": ["Spikes", "Events (Threshold)", "Events (Template)"],
            "default": "Spikes",
            "tooltip": (
                "Spikes: detect action potentials by threshold crossing.\n"
                "Events (Threshold): detect synaptic events by adaptive prominence.\n"
                "Events (Template): detect events by template/matched-filter."
            ),
        },
        {
            "name": "spike_threshold",
            "type": "float",
            "label": "AP Threshold (mV)",
            "default": 0.0,
            "min": -1e9,
            "max": 1e9,
            "decimals": 2,
            "tooltip": "Voltage threshold to detect action potentials.",
            "visible_when": {"param": "event_detection_type", "value": "Spikes"},
        },
        {
            "name": "event_threshold",
            "type": "float",
            "label": "Event Threshold (pA/mV)",
            "default": 5.0,
            "min": 0.0,
            "max": 1e9,
            "decimals": 4,
            "tooltip": "Prominence threshold for event detection.",
            "visible_when": {"param": "event_detection_type", "value": "Events (Threshold)"},
        },
        {
            "name": "event_direction",
            "type": "choice",
            "label": "Event Direction:",
            "choices": ["negative", "positive"],
            "default": "negative",
            "visible_when": {"param": "event_detection_type", "value": "Events (Threshold)"},
        },
        {
            "name": "event_refractory_s",
            "type": "float",
            "label": "Refractory (s)",
            "default": 0.002,
            "min": 0.0,
            "max": 10.0,
            "decimals": 4,
            "visible_when": {"param": "event_detection_type", "value": "Events (Threshold)"},
        },
        {
            "name": "template_tau_rise_ms",
            "type": "float",
            "label": "Tau Rise (ms)",
            "default": 0.5,
            "min": 0.0,
            "max": 1e9,
            "decimals": 3,
            "visible_when": {"param": "event_detection_type", "value": "Events (Template)"},
        },
        {
            "name": "template_tau_decay_ms",
            "type": "float",
            "label": "Tau Decay (ms)",
            "default": 5.0,
            "min": 0.0,
            "max": 1e9,
            "decimals": 3,
            "visible_when": {"param": "event_detection_type", "value": "Events (Template)"},
        },
        {
            "name": "template_threshold_sd",
            "type": "float",
            "label": "Template Threshold (SD)",
            "default": 4.0,
            "min": 0.0,
            "max": 1e9,
            "decimals": 2,
            "visible_when": {"param": "event_detection_type", "value": "Events (Template)"},
        },
        {
            "name": "template_direction",
            "type": "choice",
            "label": "Template Direction:",
            "choices": ["negative", "positive"],
            "default": "negative",
            "visible_when": {"param": "event_detection_type", "value": "Events (Template)"},
        },
        {
            "name": "template_kernel_shape",
            "label": "Template Kernel Shape:",
            "type": "choice",
            "choices": ["bi-exponential", "mono-exponential"],
            "default": "bi-exponential",
            "tooltip": (
                "bi-exponential uses distinct tau_rise and tau_decay. "
                "mono-exponential uses only tau_decay."
            ),
            "visible_when": {"param": "event_detection_type", "value": "Events (Template)"},
        },
        {
            "name": "template_kernel_multipliers",
            "label": "Template Multipliers:",
            "type": "string",
            "default": "1.0, 2.0, 3.0",
            "tooltip": (
                "Comma-separated tau_decay scaling factors for the kernel bank. "
                "E.g. '1.0, 2.0, 3.0' (Cable theory predicts ~2-3x slowdown for distal inputs)."
            ),
            "visible_when": {"param": "event_detection_type", "value": "Events (Template)"},
        },
        {
            "name": "response_polarity",
            "type": "choice",
            "label": "Peak Polarity:",
            "choices": ["max", "min", "abs"],
            "default": "max",
            "tooltip": "Direction to search for peak response voltage within the window.",
        },
        {
            "name": "amplitude_window_ms",
            "type": "float",
            "label": "Amplitude Window (ms):",
            "default": 100.0,
            "min": 0.0,
            "max": 10000.0,
            "decimals": 1,
            "tooltip": (
                "Window (ms after stimulus onset) used to find the peak response amplitude. "
                "Independent of the event-detection Response Window. Should be wide enough "
                "to cover the full response (e.g. 100 ms for slow EPSPs/EPSCs)."
            ),
        },
        {
            "name": "artifact_blanking_ms",
            "type": "float",
            "label": "Artifact Blanking (ms):",
            "default": 1.0,
            "min": 0.0,
            "max": 50.0,
            "decimals": 2,
            "tooltip": "Data within this window after each stimulus onset are excluded from peak detection.",
        },
        {
            "name": "reject_artifacts",
            "label": "Reject Slope Artifacts",
            "type": "bool",
            "default": False,
            "tooltip": (
                "Detect and mask sharp slope-transients (e.g. electrical stimulation "
                "artefacts) before event detection. Only applied when Event Type is "
                "'Events (Threshold)' or 'Events (Template)'."
            ),
        },
        {
            "name": "artifact_slope_threshold",
            "label": "Artifact Slope Thresh:",
            "type": "float",
            "default": 20.0,
            "min": 0.0,
            "max": 1e6,
            "decimals": 1,
            "tooltip": "Slope (units/ms) above which a transient is classified as an artefact.",
            "visible_when": {"param": "reject_artifacts", "value": True},
        },
        {
            "name": "artifact_padding_ms",
            "label": "Artifact Padding (ms):",
            "type": "float",
            "default": 2.0,
            "min": 0.0,
            "max": 100.0,
            "decimals": 1,
            "tooltip": "Samples within this window around each detected artefact are also masked.",
            "visible_when": {"param": "reject_artifacts", "value": True},
        },
    ],
    plots=[
        {"name": "Trace", "type": "trace", "show_events": True},
        {"type": "vlines", "data": "stimulus_onsets"},
        {"type": "markers", "x": "_peak_times", "y": "_peak_amps", "symbol": "d"},
    ],
)
def run_opto_sync_wrapper(  # noqa: C901
    data: np.ndarray, time: np.ndarray, sampling_rate: float, **kwargs
) -> Dict[str, Any]:
    """
    Registry entry point for stimulus/event synchronisation analysis.

    Steps:
        1. Unless ``action_potential_times`` is supplied in *kwargs*, detect
           events on *data* with the method chosen by ``event_detection_type``
           (spike threshold, event threshold, or template matching). For the
           two event modes an optional slope-artifact mask is built first.
        2. Correlate the events with stimulus onsets taken from ``ttl_data``
           (falling back to *data* itself when no TTL channel is given).
        3. For every stimulus, locate the peak of *data* inside the amplitude
           window (after artifact blanking) for plot markers.

    Returns:
        ``{"module_used": "evoked_responses", "metrics": {...}}``. When the
        synchronisation analysis is invalid, ``metrics`` holds only an
        ``"error"`` entry.
    """
    sync_ttl_threshold = kwargs.get("ttl_threshold", 2.5)
    sync_window_ms = kwargs.get("response_window_ms", 20.0)
    peak_window_ms = float(kwargs.get("amplitude_window_ms", 100.0))
    detection_mode = kwargs.get("event_detection_type", "Spikes")
    peak_mode = kwargs.get("response_polarity", "max")
    blanking_ms = float(kwargs.get("artifact_blanking_ms", 1.0))

    # Optional slope-based artifact mask; only meaningful for the event modes.
    mask = None
    is_event_mode = detection_mode in ("Events (Threshold)", "Events (Template)")
    if kwargs.get("reject_artifacts", False) and is_event_mode:
        slope_limit = kwargs.get("artifact_slope_threshold", 20.0)
        padding_ms = kwargs.get("artifact_padding_ms", 2.0)
        mask = find_artifact_windows(data, sampling_rate, slope_limit, padding_ms)

    def _times_or_empty(detection) -> np.ndarray:
        # Event times of a valid, non-empty detection result; otherwise empty.
        if detection.is_valid and detection.event_times is not None and len(detection.event_times) > 0:
            return detection.event_times
        return np.array([])

    ap_times = kwargs.get("action_potential_times", None)
    if ap_times is None:
        if detection_mode == "Spikes":
            spike_result = detect_spikes_threshold(
                data,
                time,
                threshold=kwargs.get("spike_threshold", 0.0),
                # Fixed 2 ms refractory period, at least one sample.
                refractory_samples=max(1, int(0.002 * sampling_rate)),
            )
            spike_idx = spike_result.spike_indices
            if spike_idx is not None and len(spike_idx) > 0:
                ap_times = time[spike_idx]
            else:
                ap_times = np.array([])
        elif detection_mode == "Events (Threshold)":
            ap_times = _times_or_empty(
                detect_events_threshold(
                    data,
                    time,
                    threshold=kwargs.get("event_threshold", 5.0),
                    polarity=kwargs.get("event_direction", "negative"),
                    refractory_period=kwargs.get("event_refractory_s", 0.002),
                    artifact_mask=mask,
                )
            )
        elif detection_mode == "Events (Template)":
            rise_s = kwargs.get("template_tau_rise_ms", 0.5) / 1000.0
            decay_s = kwargs.get("template_tau_decay_ms", 5.0) / 1000.0
            sd_threshold = kwargs.get("template_threshold_sd", 4.0)
            template_polarity = kwargs.get("template_direction", "negative")
            raw_multipliers = kwargs.get("template_kernel_multipliers", "1.0, 2.0, 3.0")
            # Parse the comma-separated multipliers; anything unparsable (or a
            # non-string value) falls back to the default kernel bank.
            try:
                multipliers = [float(tok.strip()) for tok in raw_multipliers.split(",") if tok.strip()]
                if not multipliers:
                    raise ValueError("empty")
            except (ValueError, AttributeError):
                multipliers = [1.0, 2.0, 3.0]
            ap_times = _times_or_empty(
                detect_events_template(
                    data=data,
                    sampling_rate=sampling_rate,
                    threshold_std=sd_threshold,
                    tau_rise=rise_s,
                    tau_decay=decay_s,
                    polarity=template_polarity,
                    time=time,
                    artifact_mask=mask,
                    kernel_multipliers=multipliers,
                    kernel_shape=kwargs.get("template_kernel_shape", "bi-exponential"),
                )
            )
        else:
            ap_times = np.array([])
            log.warning("Unknown event_detection_type '%s'; defaulting to no events.", detection_mode)

    ttl_data = kwargs.get("ttl_data", None)
    if ttl_data is None:
        log.debug("No TTL data provided; using voltage trace as fallback for TTL edge detection.")
        ttl_data = data

    result = calculate_optogenetic_sync(
        ttl_data=ttl_data,
        action_potential_times=ap_times,
        time=time,
        ttl_threshold=sync_ttl_threshold,
        response_window_ms=sync_window_ms,
    )
    if not result.is_valid:
        return {"module_used": "evoked_responses", "metrics": {"error": result.error_message}}

    # Peak of the trace after each stimulus. The amplitude window is
    # independent of the event-detection window, and the blanking interval
    # right after each onset is skipped to exclude the stimulus artefact.
    def _locate_peak(window: np.ndarray) -> int:
        if peak_mode == "min":
            return int(np.argmin(window))
        if peak_mode == "abs":
            return int(np.argmax(np.abs(window)))
        return int(np.argmax(window))

    marker_times: List[float] = []
    marker_amps: List[float] = []
    peak_window_s = peak_window_ms / 1000.0
    blanking_s = blanking_ms / 1000.0
    n_samples = len(data)
    if result.stimulus_onsets is not None and n_samples > 0:
        for onset in result.stimulus_onsets:
            lo = int(np.searchsorted(time, onset + blanking_s, side="left"))
            hi = int(np.searchsorted(time, onset + peak_window_s, side="right"))
            # Clip to the trace and guarantee at least one sample.
            lo = max(0, min(lo, n_samples - 1))
            hi = max(lo + 1, min(hi, n_samples))
            window = data[lo:hi]
            if len(window) == 0:
                continue
            peak_idx = lo + _locate_peak(window)
            marker_times.append(float(time[peak_idx]))
            marker_amps.append(float(data[peak_idx]))

    # Response probability as a percentage for human-readable reporting.
    if result.response_probability is not None:
        resp_prob_pct = round(result.response_probability * 100.0, 2)
    else:
        resp_prob_pct = np.nan

    onsets_out = result.stimulus_onsets.tolist() if result.stimulus_onsets is not None else []
    events_out = ap_times.tolist() if hasattr(ap_times, "tolist") else list(ap_times)
    metrics = {
        "optical_latency_ms": result.optical_latency_ms,
        "response_probability": result.response_probability,
        "response_probability_pct": resp_prob_pct,
        "spike_jitter_ms": result.spike_jitter_ms,
        "stimulus_count": result.stimulus_count,
        "Success Count": result.success_count,
        "Failure Count": result.failure_count,
        "event_count": len(ap_times),
        "event_times": events_out,
        "stimulus_onsets": onsets_out,
        "_peak_times": marker_times,
        "_peak_amps": marker_amps,
    }
    return {"module_used": "evoked_responses", "metrics": metrics}
# ---------------------------------------------------------------------------
# PPR Registry Wrapper
# ---------------------------------------------------------------------------
[docs]
@AnalysisRegistry.register(
    "paired_pulse_ratio",
    label="Paired-Pulse Ratio",
    requires_secondary_channel={
        "param_name": "ttl_data",
        "label": "TTL / Stimulus Channel:",
        "tooltip": "Optional TTL channel. When 'Detect Stim from TTL' is enabled, "
        "stimulus times are read from this channel instead of the manual spinboxes.",
    },
    plots=[
        {"name": "Trace", "type": "trace"},
        {"type": "vlines", "data": "_stim_onsets"},
        {
            "type": "trace_overlay",
            "start_time": "_baseline_start_s",
            "end_time": "_baseline_end_s",
            "color": "#00cfff",
            "width": 3,
            "opacity": 50,
        },
        {
            "type": "event_fit_overlay",
            "times_key": "_ppr_fit_times",
            "values_key": "_ppr_fit_values",
            "color": "#ff9900",
            "width": 2,
            "opacity": 85,
        },
        {"type": "markers", "x": "_peak_times", "y": "_peak_amps", "symbol": "d", "color": "#ff6600"},
    ],
    ui_params=[
        {
            "name": "use_ttl",
            "label": "Detect Stim from TTL:",
            "type": "bool",
            "default": False,
            "tooltip": "When enabled, Stim 1 and Stim 2 onsets are detected automatically "
            "from the TTL channel. Select the TTL channel in the secondary-channel "
            "dropdown above.",
        },
        {
            "name": "ttl_threshold",
            "label": "TTL Threshold (V):",
            "type": "float",
            "default": 2.5,
            "min": -100.0,
            "max": 100.0,
            "decimals": 3,
            "tooltip": "Binarisation threshold for TTL edge detection.",
            "visible_when": {"param": "use_ttl", "value": True},
        },
        {
            "name": "stim1_onset_s",
            "label": "Stim 1 Onset (s):",
            "type": "float",
            "default": 0.1,
            "min": 0.0,
            "max": 1e9,
            "decimals": 4,
            "visible_when": {"param": "use_ttl", "value": False},
        },
        {
            "name": "stim2_onset_s",
            "label": "Stim 2 Onset (s):",
            "type": "float",
            "default": 0.2,
            "min": 0.0,
            "max": 1e9,
            "decimals": 4,
            "visible_when": {"param": "use_ttl", "value": False},
        },
        {
            "name": "polarity",
            "label": "Event Polarity:",
            "type": "choice",
            "choices": ["negative", "positive"],
            "default": "negative",
        },
        {
            "name": "response_window_ms",
            "label": "Response Window (ms):",
            "type": "float",
            "default": 20.0,
            "min": 1.0,
            "max": 500.0,
            "decimals": 1,
        },
        {
            "name": "baseline_window_ms",
            "label": "Baseline Window (ms):",
            "type": "float",
            "default": 5.0,
            "min": 1.0,
            "max": 100.0,
            "decimals": 1,
        },
        {
            "name": "fit_decay_from_ms",
            "label": "Decay Fit Start (ms):",
            "type": "float",
            "default": 5.0,
            "min": 0.0,
            "max": 100.0,
            "decimals": 1,
            "tooltip": "Offset from Stim1 onset to begin fitting the decay (skip initial transient).",
        },
        {
            "name": "fit_decay_window_ms",
            "label": "Decay Fit Window (ms):",
            "type": "float",
            "default": 30.0,
            "min": 5.0,
            "max": 500.0,
            "decimals": 1,
        },
        {
            "name": "artifact_blanking_ms",
            "label": "Artifact Blanking (ms):",
            "type": "float",
            "default": 1.0,
            "min": 0.0,
            "max": 50.0,
            "decimals": 2,
            "tooltip": "Data within this window after each stimulus onset are excluded from peak detection.",
        },
    ],
)
def run_ppr_wrapper(
    data: np.ndarray,
    time: np.ndarray,
    sampling_rate: float,
    **kwargs,
) -> Dict[str, Any]:
    """Registry entry point for the Paired-Pulse Ratio analysis.

    Stimulus onsets come from the manual ``stim1_onset_s`` / ``stim2_onset_s``
    values unless ``use_ttl`` is set and a non-empty ``ttl_data`` trace is
    supplied, in which case the first two detected TTL onsets are used. With a
    single detected onset only stim 1 is replaced; with none, both manual
    values are kept.

    Returns:
        ``{"module_used": "evoked_responses", "metrics": {...}}`` holding the
        PPR results, the onsets actually used and private (underscore) keys
        consumed by the plot overlays.
    """
    use_ttl = bool(kwargs.get("use_ttl", False))
    ttl_threshold = float(kwargs.get("ttl_threshold", 2.5))
    stim1_onset_s = float(kwargs.get("stim1_onset_s", 0.1))
    stim2_onset_s = float(kwargs.get("stim2_onset_s", 0.2))

    # Auto-detect stimulus times from TTL channel when requested.
    if use_ttl:
        ttl_data = kwargs.get("ttl_data", None)
        if ttl_data is None or len(ttl_data) == 0:
            log.warning("PPR: use_ttl=True but no TTL data provided; using manual onsets.")
        else:
            onsets, _ = extract_ttl_epochs(ttl_data, time, ttl_threshold)
            n_onsets = 0 if onsets is None else len(onsets)
            if n_onsets >= 2:
                stim1_onset_s = float(onsets[0])
                stim2_onset_s = float(onsets[1])
                log.debug("PPR: TTL-detected stim1=%.4f s, stim2=%.4f s", stim1_onset_s, stim2_onset_s)
            elif n_onsets == 1:
                stim1_onset_s = float(onsets[0])
                log.warning("PPR: TTL detected only one onset; stim2 retains manual value %.4f s", stim2_onset_s)

    polarity = kwargs.get("polarity", "negative")
    response_window_ms = float(kwargs.get("response_window_ms", 20.0))
    baseline_window_ms = float(kwargs.get("baseline_window_ms", 5.0))
    fit_decay_from_ms = float(kwargs.get("fit_decay_from_ms", 5.0))
    fit_decay_window_ms = float(kwargs.get("fit_decay_window_ms", 30.0))
    artifact_blanking_ms = float(kwargs.get("artifact_blanking_ms", 1.0))

    ppr = calculate_paired_pulse_ratio(
        data=data,
        time=time,
        stim1_onset_s=stim1_onset_s,
        stim2_onset_s=stim2_onset_s,
        response_window_ms=response_window_ms,
        baseline_window_ms=baseline_window_ms,
        fit_decay_from_ms=fit_decay_from_ms,
        fit_decay_window_ms=fit_decay_window_ms,
        polarity=polarity,
        artifact_blanking_ms=artifact_blanking_ms,
    )

    # Peak positions for the scatter overlay, one (time, value) pair per stimulus.
    blank_s = artifact_blanking_ms / 1000.0
    win_s = response_window_ms / 1000.0
    peak_time_1, peak_value_1 = _peak_pos_s(data, time, stim1_onset_s, polarity, blank_s, win_s)
    peak_time_2, peak_value_2 = _peak_pos_s(data, time, stim2_onset_s, polarity, blank_s, win_s)

    metrics: Dict[str, Any] = {
        "r1_amplitude": ppr["r1_amplitude"],
        "r2_amplitude_raw": ppr["r2_amplitude_raw"],
        "r2_amplitude_corrected": ppr["r2_amplitude_corrected"],
        "residual_at_stim2": ppr["residual_at_stim2"],
        "paired_pulse_ratio": ppr["paired_pulse_ratio"],
        "decay_tau_ms": ppr["decay_tau_ms"],
        "ppr_error": ppr["ppr_error"],
        "stim1_onset_used_s": stim1_onset_s,
        "stim2_onset_used_s": stim2_onset_s,
        "_stim_onsets": [stim1_onset_s, stim2_onset_s],
        "_baseline_start_s": stim1_onset_s - baseline_window_ms / 1000.0,
        "_baseline_end_s": stim1_onset_s,
        # Fit overlay keys are absent from the PPR result when no fit was made.
        "_ppr_fit_times": ppr.get("_ppr_fit_times"),
        "_ppr_fit_values": ppr.get("_ppr_fit_values"),
        "_peak_times": [peak_time_1, peak_time_2],
        "_peak_amps": [peak_value_1, peak_value_2],
    }
    return {"module_used": "evoked_responses", "metrics": metrics}
# ---------------------------------------------------------------------------
# Stimulus Train STP
# ---------------------------------------------------------------------------
[docs]
def calculate_stimulus_train_stp(  # noqa: C901
    data: np.ndarray,
    time: np.ndarray,
    stim_onsets: np.ndarray,
    polarity: str = "negative",
    response_window_ms: float = 20.0,
    baseline_window_ms: float = 5.0,
    artifact_blanking_ms: float = 1.0,
) -> Dict[str, Any]:
    """Compute short-term plasticity (STP) amplitudes for a stimulus train.

    For each stimulus onset the function averages the samples in a baseline
    window immediately preceding the stimulus, then finds the extreme value
    in a post-stimulus window (starting after the artifact-blanking interval)
    and reports the baseline-subtracted amplitude. Amplitudes are normalised
    to the first response (R1) to yield the STP profile.

    Args:
        data: 1-D voltage or current trace (array-like).
        time: 1-D time vector (seconds, same shape as ``data``). It is
            searched with ``np.searchsorted`` and must therefore be sorted
            in ascending order.
        stim_onsets: Stimulus onset times in seconds, ordered chronologically
            (array-like; a scalar is treated as a single onset).
        polarity: ``"negative"`` measures ``baseline - min(window)``; any
            other value measures ``max(window) - baseline``.
        response_window_ms: End of the peak-search window, in milliseconds
            after each onset.
        baseline_window_ms: Duration of the pre-stimulus baseline window in
            milliseconds.
        artifact_blanking_ms: Data within this interval after each onset are
            excluded from peak detection.

    Returns:
        ``{"stp_error": <message>}`` when ``data`` has fewer than two samples
        or ``time`` and ``data`` differ in shape. Otherwise a dictionary with:

        * ``pulse_count``: number of onsets analysed.
        * ``r1_amplitude``: first amplitude rounded to 4 decimals, or
          ``None`` when there are no onsets.
        * ``stp_type``: ``"none"`` for fewer than two pulses, otherwise
          ``"facilitation"`` if R2 > R1 else ``"depression"``.
        * ``R<n>/R1``: one normalised ratio (4 decimals) per pulse n >= 2.
        * ``amplitudes``, ``amplitudes_norm``, ``pulse_numbers``: per-pulse
          lists; normalised values are NaN when R1 is exactly zero.
        * ``_stim_onsets``, ``_peak_times``, ``_peak_amps``: plotting
          overlays (onsets as a list, and the peak positions returned by
          ``_peak_pos_s``).
    """
    # Accept array-likes as well as ndarrays (no copy is made for ndarrays).
    data = np.asarray(data)
    time = np.asarray(time)
    onsets = np.atleast_1d(np.asarray(stim_onsets))

    if data.size < 2 or time.shape != data.shape:
        return {"stp_error": "Invalid data or time array"}

    blank_s = artifact_blanking_ms / 1000.0
    win_s = response_window_ms / 1000.0
    bl_s = baseline_window_ms / 1000.0
    n_samples = len(data)

    def _idx(t: float) -> int:
        """Return the index of the first sample whose time is >= ``t``."""
        return int(np.searchsorted(time, t))

    def _baseline(onset: float) -> float:
        """Return the mean of the pre-stimulus baseline window."""
        i0 = _idx(max(onset - bl_s, float(time[0])))
        # Guarantee at least one sample even when the window collapses.
        i1 = max(_idx(onset), i0 + 1)
        seg = data[i0:i1]
        if seg.size > 0:
            return float(np.mean(seg))
        # The slice is empty only when the whole baseline window lies past
        # the end of the recording; indexing at the onset would then be out
        # of bounds, so fall back to the last available sample.
        return float(data[-1])

    def _amplitude(onset: float, baseline: float) -> float:
        """Return the baseline-subtracted peak in the response window."""
        i_start = _idx(onset + blank_s)
        i_end = min(_idx(onset + win_s) + 1, n_samples)
        if i_end <= i_start:
            # No samples available after blanking: report a zero response.
            return 0.0
        seg = data[i_start:i_end]
        if polarity == "negative":
            return float(baseline - np.min(seg))
        return float(np.max(seg) - baseline)

    amplitudes: List[float] = []
    peak_times: List[float] = []
    peak_values: List[float] = []
    for onset in onsets:
        onset_s = float(onset)
        amplitudes.append(_amplitude(onset_s, _baseline(onset_s)))
        pt, pv = _peak_pos_s(data, time, onset_s, polarity, blank_s, win_s)
        peak_times.append(pt)
        peak_values.append(pv)

    n = len(amplitudes)
    pulse_numbers = list(range(1, n + 1))

    # Normalise to R1; a zero R1 makes every ratio undefined (NaN).
    r1 = amplitudes[0] if amplitudes else 1.0
    amplitudes_norm = [a / r1 if r1 != 0.0 else float("nan") for a in amplitudes]

    ratios: Dict[str, Any] = {
        f"R{i + 1}/R1": round(amplitudes_norm[i], 4) for i in range(1, n)
    }

    stp_type = "none"
    if n >= 2:
        stp_type = "facilitation" if amplitudes[1] > amplitudes[0] else "depression"

    return {
        "pulse_count": n,
        "r1_amplitude": round(amplitudes[0], 4) if amplitudes else None,
        "stp_type": stp_type,
        **ratios,
        "amplitudes": [round(a, 4) for a in amplitudes],
        "amplitudes_norm": amplitudes_norm,
        "pulse_numbers": pulse_numbers,
        "_stim_onsets": onsets.tolist(),
        "_peak_times": peak_times,
        "_peak_amps": peak_values,
    }
[docs]
@AnalysisRegistry.register(
    "stimulus_train_stp",
    label="Stimulus Train (STP)",
    requires_secondary_channel={
        "param_name": "ttl_data",
        "label": "TTL / Stimulus Channel:",
        "tooltip": "Select the TTL/trigger channel to auto-detect stimulus times. "
        "Leave unset to use the manual frequency and start-time parameters.",
    },
    ui_params=[
        {
            "name": "use_ttl",
            "label": "Detect Stim from TTL:",
            "type": "bool",
            "default": True,
            "tooltip": "When enabled, stimulus times are detected from the TTL channel. "
            "When disabled, times are generated from the frequency and start-time "
            "parameters below.",
        },
        {
            "name": "ttl_threshold",
            "label": "TTL Threshold (V):",
            "type": "float",
            "default": 2.5,
            "min": -100.0,
            "max": 100.0,
            "decimals": 3,
            "visible_when": {"param": "use_ttl", "value": True},
        },
        {
            "name": "stim_start_s",
            "label": "First Stim Onset (s):",
            "type": "float",
            "default": 0.1,
            "min": 0.0,
            "max": 1e9,
            "decimals": 4,
            "tooltip": "Time of the first stimulus pulse. Used when TTL detection is disabled.",
            "visible_when": {"param": "use_ttl", "value": False},
        },
        {
            "name": "stim_frequency_hz",
            "label": "Stim Frequency (Hz):",
            "type": "float",
            "default": 10.0,
            "min": 0.1,
            "max": 1000.0,
            "decimals": 1,
            "tooltip": "Stimulation frequency in Hz. Used when TTL detection is disabled.",
            "visible_when": {"param": "use_ttl", "value": False},
        },
        {
            "name": "n_pulses",
            "label": "Number of Pulses:",
            "type": "int",
            "default": 5,
            "min": 2,
            "max": 100,
            "tooltip": "Maximum number of stimulus pulses to include.",
        },
        {
            "name": "polarity",
            "label": "Event Polarity:",
            "type": "choice",
            "choices": ["negative", "positive"],
            "default": "negative",
        },
        {
            "name": "response_window_ms",
            "label": "Response Window (ms):",
            "type": "float",
            "default": 20.0,
            "min": 1.0,
            "max": 500.0,
            "decimals": 1,
        },
        {
            "name": "baseline_window_ms",
            "label": "Baseline Window (ms):",
            "type": "float",
            "default": 5.0,
            "min": 1.0,
            "max": 100.0,
            "decimals": 1,
        },
        {
            "name": "artifact_blanking_ms",
            "label": "Artifact Blanking (ms):",
            "type": "float",
            "default": 1.0,
            "min": 0.0,
            "max": 50.0,
            "decimals": 2,
            "tooltip": "Data within this window after each stimulus onset are excluded from peak detection.",
        },
    ],
    plots=[
        {"name": "Trace", "type": "trace"},
        {"type": "vlines", "data": "_stim_onsets"},
        {"type": "markers", "x": "_peak_times", "y": "_peak_amps", "symbol": "d", "color": "#ff6600"},
        {
            "type": "popup_xy",
            "title": "STP Profile",
            "x": "pulse_numbers",
            "y": "amplitudes_norm",
            "x_label": "Pulse Number",
            "y_label": "Normalised Amplitude (R_n / R_1)",
        },
    ],
)
def run_stimulus_train_stp_wrapper(
    data: np.ndarray,
    time: np.ndarray,
    sampling_rate: float,
    **kwargs,
) -> Dict[str, Any]:
    """Wrapper for Stimulus Train STP analysis.

    Stimulus times are either detected from an optional TTL/trigger channel
    or generated from a user-supplied frequency and start time. The onsets
    are then handed to ``calculate_stimulus_train_stp``, whose result
    dictionary is returned unchanged under ``"metrics"``.

    Args:
        data: 1-D response trace.
        time: 1-D time vector in seconds, same length as ``data``.
        sampling_rate: Accepted as part of the wrapper signature; not used
            by this wrapper.
        **kwargs: Optional parameters, read with these defaults:
            ``use_ttl`` (True), ``ttl_data`` (None), ``ttl_threshold`` (2.5),
            ``stim_start_s`` (0.1), ``stim_frequency_hz`` (10.0),
            ``n_pulses`` (5), ``polarity`` ("negative"),
            ``response_window_ms`` (20.0), ``baseline_window_ms`` (5.0),
            ``artifact_blanking_ms`` (1.0).

    Returns:
        ``{"module_used": "evoked_responses", "metrics": {...}}``. On failure
        ``metrics`` holds a single ``stp_error`` message.
    """

    def _wrap(metrics: Dict[str, Any]) -> Dict[str, Any]:
        """Package a metrics dict in the registry's standard envelope."""
        return {"module_used": "evoked_responses", "metrics": metrics}

    use_ttl = bool(kwargs.get("use_ttl", True))
    ttl_threshold = float(kwargs.get("ttl_threshold", 2.5))
    stim_start_s = float(kwargs.get("stim_start_s", 0.1))
    stim_frequency_hz = float(kwargs.get("stim_frequency_hz", 10.0))
    n_pulses = int(kwargs.get("n_pulses", 5))
    polarity = str(kwargs.get("polarity", "negative"))
    response_window_ms = float(kwargs.get("response_window_ms", 20.0))
    baseline_window_ms = float(kwargs.get("baseline_window_ms", 5.0))
    artifact_blanking_ms = float(kwargs.get("artifact_blanking_ms", 1.0))

    # An empty time vector can never yield a valid analysis; report it with
    # the same message the core routine uses instead of letting the
    # ``time[-1]`` lookup below raise IndexError.
    if len(time) == 0:
        return _wrap({"stp_error": "Invalid data or time array"})

    # --- Determine stimulus onsets ---
    stim_onsets: Optional[np.ndarray] = None
    if use_ttl:
        ttl_data = kwargs.get("ttl_data", None)
        if ttl_data is not None and len(ttl_data) > 0:
            # Second element of the returned pair is not needed here.
            detected, _ = extract_ttl_epochs(ttl_data, time, ttl_threshold)
            if detected is not None and len(detected) > 0:
                # Coerce to an ndarray so downstream array methods are safe.
                stim_onsets = np.asarray(detected)[:n_pulses]
                log.debug("STP: TTL detected %d onsets, using first %d", len(detected), len(stim_onsets))

    if stim_onsets is None or len(stim_onsets) == 0:
        # Fall back to manual frequency + start time.
        if use_ttl:
            log.warning("STP: TTL detection yielded no onsets; falling back to manual frequency parameters.")
        if stim_frequency_hz <= 0.0:
            return _wrap({"stp_error": "Stimulus frequency must be > 0 Hz"})
        isi = 1.0 / stim_frequency_hz
        stim_onsets = stim_start_s + np.arange(n_pulses) * isi
        # Clip to recording duration.
        stim_onsets = stim_onsets[stim_onsets < float(time[-1])]

    if len(stim_onsets) == 0:
        return _wrap({"stp_error": "No stimulus onsets lie within the recording duration"})

    result = calculate_stimulus_train_stp(
        data=data,
        time=time,
        stim_onsets=stim_onsets,
        polarity=polarity,
        response_window_ms=response_window_ms,
        baseline_window_ms=baseline_window_ms,
        artifact_blanking_ms=artifact_blanking_ms,
    )
    # Success and error results share the same envelope, so no branching is
    # required here: an ``stp_error`` key is simply passed through.
    return _wrap(result)
# ---------------------------------------------------------------------------
# Module-level tab aggregator
# ---------------------------------------------------------------------------
[docs]
@AnalysisRegistry.register(
    "evoked_responses",
    label="Evoked Responses",
    requires_secondary_channel=dict(
        param_name="ttl_data",
        label="TTL / Stimulus Channel:",
        tooltip="Select the digital/TTL or stimulus channel (optical or electrical).",
    ),
    method_selector={
        "Evoked Sync": "optogenetic_sync",
        "Paired-Pulse Ratio": "paired_pulse_ratio",
        "Stimulus Train (STP)": "stimulus_train_stp",
    },
    ui_params=list(),
    plots=list(),
)
def evoked_responses_module(**kwargs):
    """Aggregator tab entry for the evoked-response analyses.

    Registered with a ``method_selector`` mapping display names to the
    registry keys of the individual analyses. The function itself performs
    no computation: any keyword arguments are ignored and an empty
    dictionary is returned.
    """
    return dict()