"""
Batch Analysis Engine for Synaptipy.
Handles processing multiple files and aggregating results using a flexible registry-based pipeline.
The engine uses a registry-based architecture where analysis functions register
themselves via decorators, and the pipeline configuration defines what analyses
to run on which data scopes.
Output Design Principles
------------------------
1. Every row is fully traceable to its source (file, channel, trial, analysis).
2. Metadata columns appear first; analysis results in the middle; internal/debug last.
3. Scalar results live in their own columns; array values are summarised for tabular
compatibility (Excel, Origin, R, MATLAB) and the raw arrays are kept under
private ``_``-prefixed keys that are stripped during CSV export.
4. Channel physical units are always recorded so downstream scripts can auto-label axes.
5. Recording-level metadata (protocol, duration, session time) is propagated when available.
Author: Anzal K Shahul <anzal.ks@gmail.com>
"""
import gc
import logging
import multiprocessing
import traceback # Added for stack trace logging
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
# Import analysis package to trigger all registrations
import Synaptipy.core.analysis # noqa: F401 - Import triggers all registrations
from Synaptipy.core.analysis.cross_file_utils import average_padded_trials as get_cross_file_average
from Synaptipy.core.analysis.registry import AnalysisRegistry
from Synaptipy.core.data_model import Recording
from Synaptipy.infrastructure.file_readers import NeoAdapter
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Column layout of exported result tables:
#   metadata columns first, analysis results in the middle, bookkeeping last.
# ---------------------------------------------------------------------------

# Leading columns, listed in the exact order they should appear.
_METADATA_COLUMNS_ORDER = [
    # who / which cell
    "subject_id",
    "cell_id",
    # which recording
    "file_name",
    "file_path",
    "protocol",
    "recording_duration_s",
    # which channel
    "channel",
    "channel_units",
    # which analysis and data slice
    "analysis",
    "scope",
    "trial_index",
    "trial_count",
    "sampling_rate",
]

# Bookkeeping/debug columns moved to the end of the table, in this order.
_TRAILING_COLUMNS = ["batch_timestamp", "error", "debug_trace"]

# Descriptive alias names for terse result keys. The aliases are added as
# *extra* columns; the original keys stay so existing scripts keep working.
_HUMAN_READABLE_ALIASES: Dict[str, str] = dict(
    cv="coeff_of_variation",
    cv2="local_cv2_holt",
    lv="local_variation_shinomoto",
    fi_slope="fi_gain_hz_per_pa",
    fi_r_squared="fi_fit_r_squared",
    iv_r_squared="iv_fit_r_squared",
)
[docs]
class BatchAnalysisEngine:
"""
Engine for running analysis across multiple files/recordings using a flexible pipeline.
The engine uses a registry-based architecture where analysis functions register
themselves via decorators, and the pipeline configuration defines what analyses
to run on which data scopes.
Example::
engine = BatchAnalysisEngine()
files = [Path("file1.abf"), Path("file2.abf")]
pipeline = [
{
'analysis': 'spike_detection',
'scope': 'all_trials',
'params': {'threshold': -15.0, 'refractory_ms': 2.0}
},
{
'analysis': 'rmp_analysis',
'scope': 'average',
'params': {'baseline_start': 0.0, 'baseline_end': 0.1}
}
]
results_df = engine.run_batch(files, pipeline)
"""
def __init__(self, neo_adapter: Optional[NeoAdapter] = None, max_workers: int = 1):
"""
Initialize the batch analysis engine.
Args:
neo_adapter: Optional NeoAdapter instance. If None, creates a new one.
max_workers: Number of parallel worker processes for :meth:`run_batch`.
1 (default) means fully sequential execution.
Values > 1 enable :class:`~concurrent.futures.ProcessPoolExecutor`
parallelism. Pass ``-1`` to use all available CPU cores.
"""
self.neo_adapter = neo_adapter if neo_adapter else NeoAdapter()
self._cancelled = False
cpu_count = multiprocessing.cpu_count()
if max_workers < 0:
self.max_workers: int = cpu_count
else:
self.max_workers = max(1, int(max_workers))
[docs]
def cancel(self):
    """Ask the running batch to stop at its next cancellation checkpoint.

    This only raises the ``_cancelled`` flag. The batch loops poll it between
    files, channels and tasks, so work already in progress is not interrupted.
    """
    self._cancelled = True
    log.debug("Batch analysis cancellation requested.")
[docs]
@staticmethod
def list_available_analyses() -> List[str]:
    """
    Return the names of every analysis function known to the registry.

    Returns:
        The list produced by ``AnalysisRegistry.list_registered()``.
    """
    registered_names = AnalysisRegistry.list_registered()
    return registered_names
[docs]
@staticmethod
def get_analysis_info(name: str) -> Optional[Dict[str, Any]]:
    """
    Describe a registered analysis function.

    Args:
        name: Registry name of the analysis.

    Returns:
        A dict with ``name``, ``docstring`` (a placeholder text when the
        function has no docstring) and ``module`` keys, or ``None`` when
        *name* is not registered.
    """
    analysis_func = AnalysisRegistry.get_function(name)
    if analysis_func is not None:
        doc_text = analysis_func.__doc__ or "No documentation available."
        return {"name": name, "docstring": doc_text, "module": analysis_func.__module__}
    return None
# ------------------------------------------------------------------
# Output post-processing helpers
# ------------------------------------------------------------------
@staticmethod
def _sanitise_value(key: str, value: Any) -> Tuple[Any, Optional[Tuple[str, Any]]]:
"""Sanitise a single result value for export.
Returns:
Tuple of (replacement_value, optional (stash_key, stash_value)).
"""
if isinstance(value, np.ndarray):
return BatchAnalysisEngine._sanitise_ndarray(key, value)
if isinstance(value, list) and len(value) > 5:
return BatchAnalysisEngine._sanitise_long_list(key, value)
if not isinstance(value, (int, float, str, bool, type(None))):
return f"{type(value).__name__}", (f"_{key}_obj", value)
return value, None
@staticmethod
def _sanitise_ndarray(key: str, value: np.ndarray) -> Tuple[Any, Optional[Tuple[str, Any]]]:
"""Summarise numpy arrays for CSV-friendly output."""
if value.size <= 5:
return value.tolist(), None
summary = f"n={value.size}"
if np.issubdtype(value.dtype, np.floating):
summary = (
f"n={value.size}, "
f"mean={np.nanmean(value):.4g}, "
f"min={np.nanmin(value):.4g}, "
f"max={np.nanmax(value):.4g}"
)
return summary, (f"_{key}_raw", value)
@staticmethod
def _sanitise_long_list(key: str, value: list) -> Tuple[Any, Optional[Tuple[str, Any]]]:
"""Summarise long lists for CSV-friendly output."""
try:
arr = np.asarray(value, dtype=float)
summary = (
f"n={arr.size}, "
f"mean={np.nanmean(arr):.4g}, "
f"min={np.nanmin(arr):.4g}, "
f"max={np.nanmax(arr):.4g}"
)
return summary, (f"_{key}_raw", arr)
except (ValueError, TypeError):
return f"[{len(value)} items]", None
# Result keys whose array values are copied verbatim into
# ``result["_raw_arrays"]`` by ``_sanitise_result_for_export`` (for downstream
# plotting and long-format CSV export). Only discrete-event arrays belong
# here, such as spike times or PSP amplitudes. Continuous trace buffers
# (fitted curves, full voltage sweeps) must NOT be added, to keep memory
# bounded.
_EVENT_ARRAY_KEYS: frozenset = frozenset(
    (
        "adaptation_ratios",
        "broadening_indices",
        "current_steps",
        "event_amplitudes",
        "event_charges",
        "event_decay_times",
        "event_iei",
        "event_rise_times",
        "event_times",
        "frequencies",
        "isi_array",
        "spike_amplitudes",
        "spike_counts",
        "spike_times",
    )
)
@staticmethod
def _sanitise_result_for_export(result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Make a single result row export-friendly.

    Dual-representation architecture
    ---------------------------------
    1. Public array keys whose names appear in ``_EVENT_ARRAY_KEYS`` produce
       both a human-readable summary string (kept in the main dict so UI
       tables render correctly) *and* a copy of the raw NumPy array stored
       under ``result["_raw_arrays"][key]``. Only arrays with at most 2
       dimensions and at most 50 000 elements are stored, as a guard against
       accidentally persisting continuous traces. An event-key list that
       cannot be turned into a regular NumPy array is skipped here instead of
       raising out of this method. An example is a ragged list of per-trial
       spike-time arrays of different lengths.
    2. All other non-scalar values are summarised (see ``_sanitise_value``)
       and any raw object is moved to a ``_``-prefixed key.
    3. Human-readable alias columns are added for cryptic algorithm names
       (``_HUMAN_READABLE_ALIASES``), unless the alias key already exists.
    4. Private ``_``-prefixed keys are left untouched.

    Returns:
        The same *result* dict, modified in place.
    """
    raw_arrays: Dict[str, Any] = result.get("_raw_arrays", {})  # may already exist
    keys_to_add: Dict[str, Any] = {}
    # Iterate over a snapshot: values are rewritten while looping.
    for key, value in list(result.items()):
        if key.startswith("_"):
            continue
        new_value, stash = BatchAnalysisEngine._sanitise_value(key, value)
        result[key] = new_value
        if stash is not None:
            keys_to_add[stash[0]] = stash[1]

        # Discrete-event arrays: additionally keep the raw data.
        if key not in BatchAnalysisEngine._EVENT_ARRAY_KEYS or not isinstance(value, (np.ndarray, list)):
            continue
        if isinstance(value, np.ndarray):
            arr = value
        else:
            try:
                arr = np.asarray(value)
            except (ValueError, TypeError):
                # Inhomogeneous (ragged) list: cannot form a regular array.
                continue
        if arr.ndim <= 2 and arr.size <= 50_000:
            raw_arrays[key] = arr

    if raw_arrays:
        result["_raw_arrays"] = raw_arrays
    result.update(keys_to_add)

    # Add human-readable aliases for cryptic keys
    for orig_key, alias in _HUMAN_READABLE_ALIASES.items():
        if orig_key in result and alias not in result:
            result[alias] = result[orig_key]
    return result
@staticmethod
def _recording_metadata(recording: "Recording") -> Dict[str, Any]:
"""Extract recording-level metadata for result rows.
Includes ``subject_id`` and ``cell_id`` when set on the Recording so
that downstream hierarchical mixed-effects analyses can distinguish
between-subject (N) from within-subject (n) observations.
"""
meta: Dict[str, Any] = {}
if recording is None:
return meta
if hasattr(recording, "protocol_name") and recording.protocol_name:
meta["protocol"] = recording.protocol_name
if hasattr(recording, "duration") and recording.duration is not None:
meta["recording_duration_s"] = round(float(recording.duration), 4)
if hasattr(recording, "subject_id"):
meta["subject_id"] = recording.subject_id
if hasattr(recording, "cell_id"):
meta["cell_id"] = recording.cell_id
return meta
@staticmethod
def _order_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return *df* with its columns in export order.

    The order is:

    1. Known metadata columns, in ``_METADATA_COLUMNS_ORDER`` order.
    2. All remaining result columns, alphabetically.
    3. Known trailing columns, in ``_TRAILING_COLUMNS`` order.
    4. ``_``-prefixed private columns, alphabetically.

    An empty DataFrame is returned unchanged.
    """
    if df.empty:
        return df
    columns = list(df.columns)
    present = set(columns)
    head = [name for name in _METADATA_COLUMNS_ORDER if name in present]
    tail = [name for name in _TRAILING_COLUMNS if name in present]
    hidden = sorted(name for name in columns if name.startswith("_"))
    claimed = set(head).union(tail, hidden)
    middle = sorted(name for name in columns if name not in claimed)
    return df[head + middle + tail + hidden]
@staticmethod
def _append_batch_error_log(file_name: str, file_path_str: str, exc: Exception) -> None:
    """Append a one-line error entry to ``~/.synaptipy/logs/batch_errors.log``.

    Failures are recorded here so that a single corrupted recording never
    aborts a long batch. Each entry has the form
    ``<ISO-8601 timestamp> | <file name> | <path> | <ExcType>: <message>``.
    Multi-line exception messages are joined onto one line so the file keeps
    exactly one entry per line. The log directory is created if missing. Any
    failure to write the log is reported through ``log.warning`` and never
    raised to the caller.

    Args:
        file_name: Base name of the failed file.
        file_path_str: Full path string for the failed file.
        exc: The exception that caused the failure.
    """
    try:
        log_dir = Path.home() / ".synaptipy" / "logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        error_log_path = log_dir / "batch_errors.log"
        timestamp = datetime.now().isoformat(timespec="seconds")
        # Keep the "one line per failure" format even for multi-line messages.
        message = " ".join(str(exc).splitlines())
        entry = f"{timestamp} | {file_name} | {file_path_str} | {type(exc).__name__}: {message}\n"
        with error_log_path.open("a", encoding="utf-8") as fh:
            fh.write(entry)
    except Exception as write_exc:  # noqa: BLE001
        log.warning("Could not write to batch_errors.log: %s", write_exc)
# ------------------------------------------------------------------
# Parallel execution helpers
# ------------------------------------------------------------------
def _run_batch_parallel(  # noqa: C901
    self,
    files: List[Union[Path, "Recording"]],
    pipeline_config: List[Dict[str, Any]],
    progress_callback: Optional[Callable[[int, int, str], None]],
    channel_filter: Optional[List[str]],
) -> pd.DataFrame:
    """Distribute file-level processing across :attr:`max_workers` worker processes.

    ``str``/``Path`` items are submitted to a
    :class:`~concurrent.futures.ProcessPoolExecutor` that uses the ``spawn``
    start method. Each is run as
    ``_worker_process_file(path, pipeline_config, channel_filter)``, which is
    expected to return a list of result-row dicts. All other items (in-memory
    Recording objects) are NOT sent to the pool. They are processed in this
    process after the pool loop, via :meth:`_run_batch_sequential`.

    Rows are re-assembled in the original order of *files*, whatever order
    the futures complete in. A failed future becomes a single error row plus
    an entry in ``batch_errors.log``; it does not abort the batch. Progress
    is reported through *progress_callback* as each item finishes.

    Note:
        No ``rs_tolerance`` is forwarded. Inline recordings use
        :meth:`_run_batch_sequential`'s default, and the worker function only
        receives the three arguments shown above.
    """
    total_files = len(files)
    batch_start_time = datetime.now()

    # Paths go to the pool; pre-loaded Recording objects are handled inline
    # afterwards (pickling them is not worth the cost).
    path_tasks: List[Tuple[int, Path]] = []
    inline_recordings: List[Tuple[int, Any]] = []
    for idx, item in enumerate(files):
        if isinstance(item, (str, Path)):
            path_tasks.append((idx, Path(item)))
        else:
            inline_recordings.append((idx, item))

    # One slot per input item so the output keeps the input order.
    all_rows: List[List[Dict[str, Any]]] = [[] for _ in range(total_files)]
    completed_count = 0

    # Use spawn context on all platforms for process-safety with Qt/numpy
    ctx = multiprocessing.get_context("spawn")
    future_to_idx: Dict[Any, int] = {}
    with ProcessPoolExecutor(max_workers=self.max_workers, mp_context=ctx) as executor:
        for orig_idx, file_path in path_tasks:
            future = executor.submit(
                _worker_process_file,
                str(file_path),
                pipeline_config,
                channel_filter,
            )
            future_to_idx[future] = orig_idx

        for future in as_completed(future_to_idx):
            orig_idx = future_to_idx[future]
            file_path = files[orig_idx]
            file_name = Path(str(file_path)).name
            completed_count += 1
            try:
                all_rows[orig_idx] = future.result()
            except Exception as exc:  # noqa: BLE001
                log.error("Worker failed for %s: %s", file_path, exc, exc_info=True)
                self._append_batch_error_log(file_name, str(file_path), exc)
                all_rows[orig_idx] = [
                    {
                        "file_name": file_name,
                        "file_path": str(file_path),
                        "error": str(exc),
                        "debug_trace": traceback.format_exc(),
                    }
                ]
            # Progress + cancellation live outside any ``finally`` clause: a
            # ``break`` inside ``finally`` silently discards an in-flight
            # exception (e.g. KeyboardInterrupt) and is deprecated by PEP 765.
            if progress_callback:
                progress_callback(completed_count, total_files, f"Processed {file_name}")
            if self._cancelled:
                executor.shutdown(wait=False, cancel_futures=True)
                break

    # Process in-memory recordings sequentially (they can't be pickled reliably)
    for orig_idx, recording in inline_recordings:
        if self._cancelled:
            break
        completed_count += 1
        file_name = getattr(getattr(recording, "source_file", None), "name", f"InMemory_{orig_idx}")
        if progress_callback:
            progress_callback(completed_count, total_files, f"Processing {file_name}...")
        try:
            df_inline = self._run_batch_sequential([recording], pipeline_config, None, channel_filter)
            all_rows[orig_idx] = df_inline.to_dict("records") if not df_inline.empty else []
        except Exception as exc:  # noqa: BLE001
            log.error("Inline recording failed: %s", exc, exc_info=True)
            all_rows[orig_idx] = [
                {"file_name": file_name, "error": str(exc), "debug_trace": traceback.format_exc()}
            ]

    if progress_callback:
        msg = "Batch cancelled." if self._cancelled else "Batch analysis complete."
        progress_callback(total_files, total_files, msg)

    flat_rows = [row for rows in all_rows for row in rows]
    df = pd.DataFrame(flat_rows)
    if not df.empty:
        df["batch_timestamp"] = batch_start_time.isoformat()
        df = self._order_columns(df)
    return df
[docs]
def run_batch(  # noqa: C901
    self,
    files: List[Union[Path, "Recording"]],
    pipeline_config: List[Dict[str, Any]],
    progress_callback: Optional[Callable[[int, int, str], None]] = None,
    channel_filter: Optional[List[str]] = None,
    rs_tolerance: float = 0.20,
    cross_file_average: bool = False,
) -> pd.DataFrame:
    """
    Run a pipeline of registered analyses over a list of files/recordings.

    Dispatch order:

    1. Empty *pipeline_config*: a warning is logged and an empty DataFrame
       is returned.
    2. *cross_file_average* is true: :meth:`_run_cross_file_average`.
    3. :attr:`max_workers` > 1 **and** at least two items in *files*:
       :meth:`_run_batch_parallel`. File paths are distributed across worker
       processes via :class:`~concurrent.futures.ProcessPoolExecutor`.
    4. Otherwise: :meth:`_run_batch_sequential`.

    The call blocks until the batch finishes. GUI callers should therefore
    run it off the GUI thread, e.g. inside a
    :class:`~Synaptipy.application.gui.analysis_worker.BatchWorker` QThread.
    The cancellation flag is reset at the start of every call.

    Args:
        files: List of file paths OR Recording objects to process.
        pipeline_config: List of task dictionaries.
        progress_callback: Optional callback (current, total, status_msg).
        channel_filter: Optional list of channel names/IDs to process.
        rs_tolerance: Maximum fractional increase in series resistance compared
            to the first valid Rs measurement before a row is flagged with
            ``rs_qc_warning``. Default 0.20 (20 %). Set to ``float('inf')``
            to disable the check. Only the sequential path applies it. The
            parallel path does not forward it (a warning is logged when a
            non-default value would be ignored), and cross-file average mode
            performs no Rs check.
        cross_file_average: When true, pool every trial of every file per
            channel into one master average and run the pipeline once per
            channel instead of once per file.

    Returns:
        pandas DataFrame containing aggregated results with metadata.
    """
    self._cancelled = False
    total_files = len(files)

    # Validate pipeline config
    if not pipeline_config:
        log.warning("Empty pipeline_config provided. No analyses will be run.")
        return pd.DataFrame()

    # Route to cross-file average mode when requested
    if cross_file_average:
        log.info("BatchAnalysisEngine: cross-file average mode enabled (%d files).", total_files)
        return self._run_cross_file_average(files, pipeline_config, progress_callback, channel_filter)

    # Route to parallel executor when max_workers > 1 and we have multiple files
    if self.max_workers > 1 and total_files > 1:
        log.info(
            "BatchAnalysisEngine: starting parallel batch (%d workers, %d files).", self.max_workers, total_files
        )
        # _run_batch_parallel takes no rs_tolerance, so a custom value cannot
        # reach the workers: surface that instead of dropping it silently.
        if rs_tolerance != 0.20:
            log.warning(
                "rs_tolerance=%s is not forwarded in parallel mode (max_workers=%d); "
                "run with max_workers=1 to apply a custom series-resistance tolerance.",
                rs_tolerance,
                self.max_workers,
            )
        return self._run_batch_parallel(files, pipeline_config, progress_callback, channel_filter)

    return self._run_batch_sequential(files, pipeline_config, progress_callback, channel_filter, rs_tolerance)
def _run_cross_file_average(  # noqa: C901
    self,
    files: List[Union[Path, "Recording"]],
    pipeline_config: List[Dict[str, Any]],
    progress_callback: Optional[Callable[[int, int, str], None]],
    channel_filter: Optional[List[str]],
) -> pd.DataFrame:
    """Pool all trials from all files per channel, compute the grand average,
    then execute the pipeline ONCE per channel on that master trace.

    Phase 1 loads each item. Paths go through ``self.neo_adapter``; Recording
    objects are used directly. Every trial's data and relative time vector is
    collected per channel name. The first file that contributes a channel
    fixes that channel's reference sampling rate and units. A later file whose
    channel has a different sampling rate is skipped for that channel, with a
    warning. Averaging sample-by-sample across different rates would
    misalign the traces in time and silently corrupt the master average.

    Phase 2 averages each channel's trials with ``get_cross_file_average``
    and builds a time vector from the longest contributing trial. It then
    calls every pipeline analysis on the result. The task ``scope`` is only
    recorded in the output; every task runs on the master average.

    Returns:
        DataFrame with one row per (channel, analysis) pair, ``file_name`` set
        to ``"CROSS_FILE_MASTER_AVERAGE"`` and ``trial_count`` equal to the
        number of pooled trials. Failures become rows with an ``error`` column.
    """
    batch_start_time = datetime.now()
    total_files = len(files)

    # ------------------------------------------------------------------
    # Phase 1: collect all trial arrays per channel across every file
    # ------------------------------------------------------------------
    # {channel_name: {"trials": [...], "times": [...], "sampling_rate": ..., "units": ...}}
    channel_data: Dict[str, Dict[str, Any]] = {}
    for i, item in enumerate(files):
        if self._cancelled:
            break
        file_name = "Unknown"
        recording = None
        try:
            if isinstance(item, (str, Path)):
                file_path = Path(item)
                file_name = file_path.name
                if progress_callback:
                    progress_callback(i, total_files, f"Loading {file_name}...")
                recording = self.neo_adapter.read_recording(file_path, channel_whitelist=channel_filter)
                if not recording:
                    log.warning("Cross-file avg: failed to load %s", file_path)
                    continue
            else:
                recording = item
                src = getattr(recording, "source_file", None)
                file_name = src.name if src else f"InMemory_{i}"
                if progress_callback:
                    progress_callback(i, total_files, f"Loading {file_name}...")

            channels_to_process = list(recording.channels.items())
            if channel_filter:
                channels_to_process = [
                    (k, ch) for k, ch in channels_to_process if k in channel_filter or str(k) in channel_filter
                ]

            for channel_key, channel in channels_to_process:
                native_name = getattr(channel, "name", None)
                channel_name = native_name if native_name else channel_key
                ch_entry = channel_data.get(channel_name)
                if ch_entry is None:
                    ch_entry = {
                        "trials": [],
                        "times": [],
                        "sampling_rate": channel.sampling_rate,
                        "units": getattr(channel, "units", "unknown"),
                    }
                    channel_data[channel_name] = ch_entry
                elif not self._sampling_rates_match(ch_entry["sampling_rate"], channel.sampling_rate):
                    log.warning(
                        "Cross-file avg: skipping channel %s of %s - sampling rate %s does not match "
                        "the reference rate %s of the first contributing file.",
                        channel_name,
                        file_name,
                        channel.sampling_rate,
                        ch_entry["sampling_rate"],
                    )
                    continue
                for trial_idx in range(channel.num_trials):
                    trial_data = channel.get_data(trial_idx)
                    trial_time = channel.get_relative_time_vector(trial_idx)
                    if trial_data is not None and trial_time is not None:
                        ch_entry["trials"].append(trial_data)
                        ch_entry["times"].append(trial_time)
        except Exception as exc:  # noqa: BLE001
            log.error("Cross-file avg: error loading %s: %s", file_name, exc, exc_info=True)
        finally:
            # Drop the Recording reference as soon as its trials are copied out.
            recording = None

    if progress_callback:
        progress_callback(total_files, total_files, "Computing cross-file averages...")

    # ------------------------------------------------------------------
    # Phase 2: compute grand average per channel, run pipeline once
    # ------------------------------------------------------------------
    results_list: List[Dict[str, Any]] = []
    for channel_name, ch_data in channel_data.items():
        master_trial_list = ch_data["trials"]
        if not master_trial_list:
            continue
        master_array = get_cross_file_average(master_trial_list)
        if master_array is None:
            log.warning("Cross-file avg: no valid trials for channel %s", channel_name)
            continue

        # Derive a reference time vector from the longest contributing trial
        lengths = [len(t) for t in ch_data["times"]]
        longest_idx = int(np.argmax(lengths))
        master_time = ch_data["times"][longest_idx][: len(master_array)]
        sampling_rate = ch_data["sampling_rate"]
        trial_count = len(master_trial_list)

        ch_meta: Dict[str, Any] = {
            "file_name": "CROSS_FILE_MASTER_AVERAGE",
            "file_path": f"CROSS_FILE_MASTER_AVERAGE ({total_files} files)",
            "channel": channel_name,
            "channel_units": ch_data["units"],
            "trial_count": trial_count,
        }

        for task in pipeline_config:
            if self._cancelled:
                break
            analysis_name = task.get("analysis")
            params = task.get("params", {})
            scope = task.get("scope", "average")
            analysis_func = AnalysisRegistry.get_function(analysis_name)
            if analysis_func is None:
                results_list.append(
                    {
                        **ch_meta,
                        "analysis": analysis_name,
                        "scope": scope,
                        "error": f"Analysis '{analysis_name}' not registered",
                    }
                )
                continue
            try:
                # A trial index is meaningless for the pooled master trace.
                p = {k: v for k, v in params.items() if k != "trial_index"}
                res = analysis_func(master_array, master_time, sampling_rate, **p)
                # Flatten consolidated-module schema
                if "metrics" in res and isinstance(res.get("metrics"), dict):
                    metrics = res.pop("metrics")
                    res.update(metrics)
                res.update(
                    {
                        **ch_meta,
                        "analysis": analysis_name,
                        "scope": scope,
                        "sampling_rate": sampling_rate,
                    }
                )
                self._sanitise_result_for_export(res)
                results_list.append(res)
            except Exception as exc:  # noqa: BLE001
                log.error("Cross-file avg: analysis %s failed: %s", analysis_name, exc, exc_info=True)
                results_list.append(
                    {
                        **ch_meta,
                        "analysis": analysis_name,
                        "scope": scope,
                        "sampling_rate": sampling_rate,
                        "error": str(exc),
                        "debug_trace": traceback.format_exc(),
                    }
                )

    if progress_callback:
        msg = "Cross-file average cancelled." if self._cancelled else "Cross-file average complete."
        progress_callback(total_files, total_files, msg)

    df = pd.DataFrame(results_list)
    if not df.empty:
        df["batch_timestamp"] = batch_start_time.isoformat()
        df = self._order_columns(df)
    return df


@staticmethod
def _sampling_rates_match(reference_rate: Any, candidate_rate: Any) -> bool:
    """Return whether two sampling rates agree within a relative tolerance of 1e-6.

    If either value is ``None`` the rates cannot be checked, and they count
    as a match, so data is only dropped on a definite mismatch. Values that
    cannot be converted to ``float`` fall back to a plain ``==`` comparison.
    """
    if reference_rate is None or candidate_rate is None:
        return True
    try:
        return bool(np.isclose(float(candidate_rate), float(reference_rate), rtol=1e-6, atol=0.0))
    except (TypeError, ValueError):
        return bool(reference_rate == candidate_rate)
def _run_batch_sequential(  # noqa: C901
    self,
    files: List[Union[Path, "Recording"]],
    pipeline_config: List[Dict[str, Any]],
    progress_callback: Optional[Callable[[int, int, str], None]],
    channel_filter: Optional[List[str]],
    rs_tolerance: float = 0.20,
) -> pd.DataFrame:
    """Process *files* one after another in the current process.

    For every file (path or Recording) and every selected channel, each task
    in *pipeline_config* runs via :meth:`_process_task`. A per-channel data
    context is passed along, and preprocessing tasks may replace it. Result
    rows are then:

    * enriched with channel/recording metadata (existing keys win);
    * sanitised for export;
    * checked for series-resistance drift (:meth:`_apply_rs_stability_qc`).

    Failures are contained. A failing task yields an error row for that task.
    A failing file yields an error row plus an entry in ``batch_errors.log``.
    Either way the batch continues with the next item.

    Args:
        files: File paths or Recording objects.
        pipeline_config: Task dictionaries passed to :meth:`_process_task`.
        progress_callback: Optional ``(current, total, status_msg)`` callback.
        channel_filter: Optional channel names/IDs to keep, matched against
            the channel key or its ``str``.
        rs_tolerance: Fractional Rs increase over the channel's reference Rs
            before ``rs_qc_warning`` is set on a row.

    Returns:
        DataFrame of all rows. When non-empty, a ``batch_timestamp`` column is
        added and columns are ordered by :meth:`_order_columns`.
    """
    results_list: List[Dict[str, Any]] = []
    total_files = len(files)
    batch_start_time = datetime.now()
    # Pre-bind the loop index so the final "cancelled" progress report is
    # valid even when *files* is empty.
    i = 0

    for i, item in enumerate(files):
        if self._cancelled:
            log.debug("Batch analysis cancelled by user.")
            if progress_callback:
                progress_callback(i, total_files, "Cancelled")
            break

        file_name = "Unknown"
        file_path_str = "InMemory"
        file_path = None
        recording = None
        try:
            if isinstance(item, (str, Path)):
                file_path = Path(item)
                file_name = file_path.name
                file_path_str = str(file_path)
                if progress_callback:
                    progress_callback(i, total_files, f"Processing {file_name}...")
                # Load only the whitelisted channels to limit memory use.
                recording = self.neo_adapter.read_recording(file_path, channel_whitelist=channel_filter)
                if not recording:
                    log.warning("Failed to load %s", file_path)
                    results_list.append(
                        {"file_name": file_name, "file_path": file_path_str, "error": "Failed to load recording"}
                    )
                    continue
            else:
                # Assume it is a Recording object
                recording = item
                if hasattr(recording, "source_file") and recording.source_file:
                    file_path = recording.source_file
                    file_name = recording.source_file.name
                    file_path_str = str(recording.source_file)
                else:
                    # Fallback for purely in-memory recordings
                    file_path = Path(f"InMemory_Recording_{i}")
                    file_name = file_path.name
                    file_path_str = str(file_path)
                if progress_callback:
                    progress_callback(i, total_files, f"Processing {file_name}...")

            channels_to_process = recording.channels.items()
            if channel_filter:
                log.debug("Applying channel filter: %s", channel_filter)
                channels_to_process = [
                    (name, ch)
                    for name, ch in recording.channels.items()
                    if name in channel_filter or str(name) in channel_filter
                ]
                if not channels_to_process:
                    log.warning("Channel filter %s matched no channels in %s.", channel_filter, file_name)
            log.debug(
                "Processing %d channels: %s",
                len(channels_to_process),
                [n for n, _ in channels_to_process],
            )

            # Recording-level metadata is extracted once per file.
            rec_meta = self._recording_metadata(recording)

            for channel_key, channel in channels_to_process:
                if self._cancelled:
                    break

                # Prefer the native channel name from the acquisition file
                # header; fall back to the channel key (ID).
                native_channel_name = getattr(channel, "name", None)
                channel_name = native_channel_name if native_channel_name else channel_key

                ch_meta = {
                    "channel_units": getattr(channel, "units", "unknown"),
                    "trial_count": getattr(channel, "num_trials", 0),
                }
                ch_meta.update(rec_meta)

                # Data buffer shared by the tasks of this channel.
                pipeline_context = {"scope": None, "data": None, "time": None}
                # First valid Rs of this channel; later sweeps are compared to it.
                rs_reference_mohm: Optional[float] = None

                for task in pipeline_config:
                    if self._cancelled:
                        break
                    try:
                        task_results, updated_context = self._process_task(
                            task=task,
                            channel=channel,
                            channel_name=channel_name,
                            file_path=file_path,
                            context=pipeline_context,
                        )
                        # Preprocessing tasks hand back a modified context.
                        if updated_context:
                            pipeline_context = updated_context

                        for res in task_results:
                            for mk, mv in ch_meta.items():
                                res.setdefault(mk, mv)
                            self._sanitise_result_for_export(res)

                        rs_reference_mohm = self._apply_rs_stability_qc(
                            task_results, rs_reference_mohm, rs_tolerance, file_name, channel_name
                        )
                        results_list.extend(task_results)
                    except Exception as e:  # noqa: BLE001 - broad catch intentional for fault-tolerance
                        log.error(
                            "Error processing task %s on %s/%s: %s",
                            task.get("analysis", "unknown"),
                            file_path.name,
                            channel_name,
                            e,
                            exc_info=True,
                        )
                        # Error row keeps full metadata so it can be filtered.
                        error_row = {
                            "file_name": file_path.name,
                            "file_path": str(file_path),
                            "channel": channel_name,
                            "analysis": task.get("analysis", "unknown"),
                            "scope": task.get("scope", "unknown"),
                            "sampling_rate": getattr(channel, "sampling_rate", None),
                            "error": str(e),
                            "debug_trace": traceback.format_exc(),
                        }
                        error_row.update(ch_meta)
                        results_list.append(error_row)

                # Release the per-channel buffer (possibly large NumPy arrays)
                # before moving on to the next channel of this file.
                pipeline_context = {"scope": None, "data": None, "time": None}

        except Exception as e:  # noqa: BLE001 - broad catch intentional; Domino Defense
            # One corrupted or unreadable file must never abort the whole run.
            log.error("Error processing batch file %s: %s", file_path, e, exc_info=True)
            self._append_batch_error_log(file_name, file_path_str, e)
            results_list.append(
                {
                    "file_name": file_name,
                    "file_path": file_path_str,
                    "error": str(e),
                    "debug_trace": traceback.format_exc(),
                }
            )
            continue
        finally:
            # Drop the Recording reference right after each file, and collect
            # garbage every 10 files or whenever the result list is large
            # (> 500 rows), to keep memory bounded on long headless runs.
            recording = None  # noqa: F841 # drop reference
            if (i + 1) % 10 == 0 or len(results_list) > 500:
                gc.collect()
                log.debug("gc.collect() called after processing item %d (results: %d rows).", i, len(results_list))

    if progress_callback:
        if self._cancelled:
            progress_callback(i, total_files, "Batch analysis cancelled.")
        else:
            progress_callback(total_files, total_files, "Batch analysis complete.")

    df = pd.DataFrame(results_list)
    if not df.empty:
        df["batch_timestamp"] = batch_start_time.isoformat()
        df = self._order_columns(df)
    return df


@staticmethod
def _apply_rs_stability_qc(
    task_results: List[Dict[str, Any]],
    rs_reference_mohm: Optional[float],
    rs_tolerance: float,
    file_name: str,
    channel_name: Any,
) -> Optional[float]:
    """Flag rows whose ``rs_mohm`` rose more than *rs_tolerance* above the reference.

    Rows are visited in order. The first usable ``rs_mohm`` value becomes the
    reference when none exists yet. Any later value above
    ``reference * (1 + rs_tolerance)`` gets an ``rs_qc_warning`` string added
    to its row in place. Only finite, strictly positive values are usable. A
    zero or negative reference (e.g. from a failed fit) would make the drift
    percentage divide by zero or change sign.

    Returns:
        The reference Rs after processing (``None`` if still unset).
    """
    for res in task_results:
        rs_val = res.get("rs_mohm")
        if rs_val is None:
            continue
        try:
            rs_float = float(rs_val)
        except (TypeError, ValueError):
            continue
        if not np.isfinite(rs_float) or rs_float <= 0.0:
            continue
        if rs_reference_mohm is None:
            rs_reference_mohm = rs_float
            log.debug(
                "Rs reference %.2f MOhm set for %s / %s.",
                rs_float,
                file_name,
                channel_name,
            )
        elif rs_float > rs_reference_mohm * (1.0 + rs_tolerance):
            delta_pct = (rs_float - rs_reference_mohm) / rs_reference_mohm * 100.0
            log.warning(
                "Series resistance destabilized: Rs=%.2f MOhm "
                "(ref=%.2f MOhm, +%.1f%%) in %s / %s trial %s "
                "(tolerance=%.0f%%).",
                rs_float,
                rs_reference_mohm,
                delta_pct,
                file_name,
                channel_name,
                res.get("trial_index", "?"),
                rs_tolerance * 100.0,
            )
            res["rs_qc_warning"] = (
                f"Series resistance destabilized: "
                f"Rs={rs_float:.1f} MOhm (ref={rs_reference_mohm:.1f} "
                f"MOhm, +{delta_pct:.1f}%)"
            )
    return rs_reference_mohm
def _process_task(  # noqa: C901
    self, task: Dict[str, Any], channel, channel_name: str, file_path: Path, context: Dict[str, Any]
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
    """
    Process a single analysis task on a channel, supporting preprocessing.

    Data for the task is reused from ``context`` when a previous
    preprocessing step produced it with the same scope (or with
    ``"all_trials"`` when an average is requested); otherwise it is loaded
    from ``channel`` according to ``task["scope"]``:

    * ``"first_trial"`` / ``"specific_trial"`` (``params["trial_index"]``)
    * ``"average"`` / ``"selected_trials_average"``
    * ``"all_trials"`` / ``"selected_trials"`` / ``"channel_set"``

    ``"selected_trials"`` and ``"selected_trials_average"`` honour an optional
    ``params["trial_indices"]`` selection string (parsed strictly).

    Args:
        task: Task configuration dict with ``"analysis"`` (registry name),
            optional ``"scope"`` (default ``"first_trial"``) and optional
            ``"params"`` (forwarded as keyword arguments to the analysis).
        channel: Channel object (reads ``sampling_rate``, ``num_trials`` and
            the data/time-vector accessors).
        channel_name: Name/ID written into every output row.
        file_path: Path of the source recording, written into every row.
        context: Data context from previous steps with keys ``"scope"``,
            ``"data"`` and ``"time"``. Contexts produced by this method also
            carry ``"trial_indices"``: the source trial number of each cached
            trace, or ``None`` when unknown.

    Returns:
        Tuple of (list of result/error rows, context). Successful
        preprocessing returns no rows and a new context; failed preprocessing
        returns an error row and a copy of the incoming context; analysis
        tasks return ``None`` as the context (no update).
    """
    from collections.abc import MutableMapping

    analysis_name = task.get("analysis")
    scope = task.get("scope", "first_trial")
    params = task.get("params", {})

    def _error_row(message: str, *, include_rate: bool = False, **trailing: Any) -> Dict[str, Any]:
        """Build an error row carrying the same traceability keys as result rows."""
        row: Dict[str, Any] = {
            "file_name": file_path.name,
            "file_path": str(file_path),
            "channel": channel_name,
            "analysis": analysis_name,
            "scope": scope,
        }
        if include_rate:
            # Only requested after ``sampling_rate`` has been assigned below.
            row["sampling_rate"] = sampling_rate
        row["error"] = message
        row.update(trailing)
        return row

    def _parse_selection(selection: str, n_trials: int) -> Tuple[Optional[List[int]], Optional[str]]:
        """Strictly parse a trial-selection string: (sorted indices, None) or (None, error text)."""
        from Synaptipy.shared.utils import parse_trial_selection_string

        try:
            parsed = parse_trial_selection_string(selection, n_trials, strict=True)
        except ValueError as exc:
            log.error("Invalid trial selection string in %s/%s: %s", file_path.name, channel_name, exc)
            return None, str(exc)
        return sorted(parsed), None

    def _source_ids(traces: Any, ids: Optional[List[int]]) -> List[int]:
        """Trial numbers of cached traces; list positions when the context has no matching ids."""
        if ids is not None and len(ids) == len(traces):
            return list(ids)
        return list(range(len(traces)))

    def _describe_length_mismatch(traces: List[Any], labels: List[int]) -> Optional[str]:
        """Describe which trials have which sample counts, or None when all lengths agree."""
        by_length: Dict[int, List[int]] = {}
        for label, trace in zip(labels, traces):
            by_length.setdefault(len(trace), []).append(label)
        if len(by_length) <= 1:
            return None
        return ", ".join(
            f"{length} samples (trials {','.join(map(str, members))})" for length, members in sorted(by_length.items())
        )

    def _load_trials(indices) -> Tuple[List[Any], List[Any], List[int]]:
        """Load the given trials, skipping ones whose data is None, and record which trials were kept."""
        traces: List[Any] = []
        times: List[Any] = []
        kept: List[int] = []
        for trial in indices:
            trace = channel.get_data(trial)
            trace_time = channel.get_relative_time_vector(trial)
            if trace is not None:
                traces.append(trace)
                times.append(trace_time)
                kept.append(trial)
        return traces, times, kept

    def _as_result_dict(res: Any) -> Any:
        """Validate an analysis return value and flatten the consolidated-module schema."""
        if not isinstance(res, MutableMapping):
            raise TypeError(f"Analysis '{analysis_name}' returned {type(res).__name__}; expected a dict of results")
        # Flatten consolidated-module schema: {"module_used": ..., "metrics": {...}}
        if isinstance(res.get("metrics"), dict):
            res.update(res.pop("metrics"))
        return res

    # Check metadata for type and batch-dispatch flags
    meta = AnalysisRegistry.get_metadata(analysis_name)
    is_preprocessing = meta.get("type") == "preprocessing"
    expects_list = meta.get("expects_list", False)

    # Get the registered analysis function
    analysis_func = AnalysisRegistry.get_function(analysis_name)
    if analysis_func is None:
        # Provide helpful suggestions using fuzzy string matching
        from difflib import get_close_matches

        available_analyses = AnalysisRegistry.list_analysis()
        error_msg = f"Analysis function '{analysis_name}' not registered"
        # str() keeps get_close_matches from crashing when the task has no "analysis" key.
        suggestions = get_close_matches(str(analysis_name), available_analyses, n=3, cutoff=0.6)
        if suggestions:
            error_msg += f". Did you mean: {', '.join(suggestions)}?"
        else:
            error_msg += f". Available analyses: {', '.join(sorted(available_analyses)[:10])}"
            if len(available_analyses) > 10:
                error_msg += f" (and {len(available_analyses) - 10} more)"
        log.error(error_msg)
        return [_error_row(error_msg)], None

    results: List[Dict[str, Any]] = []
    sampling_rate = channel.sampling_rate

    # --- Data Retrieval Strategy ---
    # 1. If context matches requested scope, use it.
    # 2. If context holds 'all_trials' but an average is requested, average the cached trials.
    # 3. Otherwise (or if adaptation fails), load from channel.
    data = None
    time = None
    # Source trial number of each trace in a list-valued ``data``, when known.
    trial_ids: Optional[List[int]] = None

    if context["data"] is not None:
        cached = context["data"]
        cached_ids = context.get("trial_indices")

        if context["scope"] == scope:
            data = cached
            time = context["time"]
            trial_ids = cached_ids

        elif context["scope"] == "all_trials" and scope == "average":
            # Compute average from cached trials. Guard against mixed-protocol
            # files where trials can have different sample counts.
            try:
                if len(cached) > 0:
                    mismatch = _describe_length_mismatch(cached, _source_ids(cached, cached_ids))
                    if mismatch is not None:
                        log.warning(
                            "Cannot average trials with mismatched lengths in %s/%s: %s. "
                            "Use 'first_trial' or 'specific_trial' scope instead, or ensure "
                            "all sweeps use the same protocol duration.",
                            file_path.name,
                            channel_name,
                            mismatch,
                        )
                        return [
                            _error_row("Cannot average mixed-length trials", error_type="TRIAL_LENGTH_MISMATCH")
                        ], None
                    averaged = np.mean(np.array(cached), axis=0)
                    first_time = context["time"][0]
                    # Assign together so a failure above leaves data None and triggers a reload.
                    data, time = averaged, first_time
                else:
                    log.warning("Context data empty, cannot average.")
            except Exception as e:
                log.warning(
                    "Could not average trials from context (%s/%s): %s. Reloading from source.",
                    file_path.name,
                    channel_name,
                    e,
                )

        elif context["scope"] == "all_trials" and scope == "selected_trials_average":
            try:
                ids = _source_ids(cached, cached_ids)
                position_of = {trial: pos for pos, trial in enumerate(ids)}
                selection = params.get("trial_indices", "")
                if selection:
                    # Selection refers to trial numbers; with no skipped trials this equals len(cached).
                    selected_ids, parse_error = _parse_selection(selection, max(ids) + 1 if ids else 0)
                    if parse_error is not None:
                        return [_error_row(parse_error)], None
                else:
                    selected_ids = ids
                chosen_ids = [trial for trial in selected_ids if trial in position_of]
                if chosen_ids:
                    selected_data = [cached[position_of[trial]] for trial in chosen_ids]
                    mismatch = _describe_length_mismatch(selected_data, chosen_ids)
                    if mismatch is not None:
                        log.warning(
                            "Cannot average selected trials with mismatched lengths in %s/%s: %s. "
                            "Ensure all selected sweeps use the same protocol duration.",
                            file_path.name,
                            channel_name,
                            mismatch,
                        )
                        return [
                            _error_row("Cannot average mixed-length trials", error_type="TRIAL_LENGTH_MISMATCH")
                        ], None
                    averaged = np.mean(np.array(selected_data), axis=0)
                    first_time = context["time"][0]
                    data, time = averaged, first_time
                else:
                    log.warning("No valid trials selected for averaging from context.")
            except Exception as e:
                log.warning(
                    "Could not average selected trials from context (%s/%s): %s. Reloading from source.",
                    file_path.name,
                    channel_name,
                    e,
                )

    # If data is still None, load from channel
    if data is None:
        # Validate scope against available data
        if scope in ("all_trials", "selected_trials", "channel_set") and channel.num_trials == 0:
            log.error(
                "Scope '%s' requires trials, but channel %s in %s has no trials loaded.",
                scope,
                channel_name,
                file_path.name,
            )
            return [_error_row(f"Scope '{scope}' requires trials but channel has no trials")], None

        if scope == "average":
            data = channel.get_averaged_data()
            time = channel.get_relative_averaged_time_vector()
        elif scope in ("all_trials", "channel_set"):
            data, time, trial_ids = _load_trials(range(channel.num_trials))
        elif scope == "selected_trials":
            selection = params.get("trial_indices", "")
            if selection:
                selected, parse_error = _parse_selection(selection, channel.num_trials)
                if parse_error is not None:
                    return [_error_row(parse_error)], None
            else:
                selected = list(range(channel.num_trials))
            data, time, trial_ids = _load_trials(selected)
        elif scope == "selected_trials_average":
            selection = params.get("trial_indices", "")
            selected = None  # None -> channel averages all trials
            if selection:
                selected, parse_error = _parse_selection(selection, channel.num_trials)
                if parse_error is not None:
                    return [_error_row(parse_error)], None
            data = channel.get_averaged_data(trial_indices=selected)
            time = channel.get_relative_averaged_time_vector()
        elif scope == "first_trial":
            data = channel.get_data(0)
            time = channel.get_relative_time_vector(0)
        elif scope == "specific_trial":
            idx = int(params.get("trial_index", 0))
            data = channel.get_data(idx)
            time = channel.get_relative_time_vector(idx)

    # Validation
    if data is None or (isinstance(data, list) and len(data) == 0):
        return [_error_row("No data available")], None

    # --- expects_list enforcement ---
    # For multi-trial scopes ("all_trials", "selected_trials"):
    #   expects_list=False (default): call the analysis once per trial
    #   expects_list=True: pass the complete list in a single call (like channel_set)

    # --- Execution ---
    if is_preprocessing:
        # Copy of the incoming context, returned unchanged on failure so a
        # broken step does not contaminate subsequent tasks.
        original_context = context.copy() if context else None
        try:
            # Preprocessing functions are called as (data, time, fs, **params)
            # and their return value is taken as the new data.
            if isinstance(data, list):
                modified_data = []
                modified_time = []
                for trace, trace_time in zip(data, time):
                    modified_data.append(analysis_func(trace, trace_time, sampling_rate, **params))
                    modified_time.append(trace_time)  # Assume time unchanged
            else:
                modified_data = analysis_func(data, time, sampling_rate, **params)
                modified_time = time
            new_context = {
                "scope": scope,
                "data": modified_data,
                "time": modified_time,
                # Carry trial numbers forward so later per-trial rows stay correctly labelled.
                "trial_indices": trial_ids if isinstance(data, list) else None,
            }
            return [], new_context
        except Exception as e:
            log.error("Preprocessing failed: %s", e, exc_info=True)
            return [
                _error_row(
                    f"Preprocessing failed: {e}",
                    include_rate=True,
                    debug_trace=traceback.format_exc(),
                )
            ], original_context

    # Standard Analysis
    try:
        total_trials = getattr(channel, "num_trials", 0)

        def run_single(d, t, trial_idx=None):
            """Run the analysis on one trace and attach traceability metadata."""
            # trial_index selects data; it is not an analysis parameter.
            p = params.copy()
            p.pop("trial_index", None)
            res = _as_result_dict(analysis_func(d, t, sampling_rate, **p))
            res.update(
                {
                    "file_name": file_path.name,
                    "file_path": str(file_path),
                    "channel": channel_name,
                    "analysis": analysis_name,
                    "scope": scope,
                    "sampling_rate": sampling_rate,
                    "trial_count": total_trials,
                }
            )
            if trial_idx is not None:
                res["trial_index"] = trial_idx
            return res

        multi_trial = scope in ("all_trials", "selected_trials")
        if scope == "channel_set" or (multi_trial and expects_list):
            # Pass the full list (channel_set always; multi-trial scopes when expects_list=True)
            res = _as_result_dict(analysis_func(data, time, sampling_rate, **params))
            res.update(
                {
                    "file_name": file_path.name,
                    "file_path": str(file_path),
                    "channel": channel_name,
                    "analysis": analysis_name,
                    "scope": scope,
                    "sampling_rate": sampling_rate,
                    "trial_count": len(data) if isinstance(data, list) else 1,
                }
            )
            results.append(res)
        elif multi_trial:
            # Iterate per trial; label each row with its source trial number.
            if trial_ids is not None and len(trial_ids) == len(data):
                indices_list = list(trial_ids)
            elif scope == "selected_trials" and params.get("trial_indices", ""):
                indices_list, parse_error = _parse_selection(params.get("trial_indices", ""), total_trials)
                if parse_error is not None:
                    return [_error_row(parse_error)], None
            else:
                indices_list = list(range(total_trials))
            for position, (trace, trace_time) in enumerate(zip(data, time)):
                real_idx = indices_list[position] if position < len(indices_list) else position
                results.append(run_single(trace, trace_time, real_idx))
        elif scope == "specific_trial":
            results.append(run_single(data, time, int(params.get("trial_index", 0))))
        else:
            # Single trace (average, first_trial, selected_trials_average)
            results.append(run_single(data, time))
        return results, None  # No context update
    except Exception as e:
        log.error("Analysis failed: %s", e, exc_info=True)
        return [
            _error_row(
                f"Analysis failed: {e}",
                include_rate=True,
                debug_trace=traceback.format_exc(),
            )
        ], None
# ---------------------------------------------------------------------------
# Module-level worker function for ProcessPoolExecutor
# ---------------------------------------------------------------------------
def _worker_process_file(
    file_path_str: str,
    pipeline_config: List[Dict[str, Any]],
    channel_filter: Optional[List[str]],
) -> List[Dict[str, Any]]:
    """Run the batch pipeline on one file inside a worker process.

    Intended as the unit of work submitted to
    :class:`~concurrent.futures.ProcessPoolExecutor`. A spawned worker starts
    from a fresh interpreter, so the analysis package is imported first; that
    import runs every ``@AnalysisRegistry.register`` decorator. The work is
    then handed to a :class:`BatchAnalysisEngine` built with
    ``max_workers=1`` (sequential), so the worker never nests its own
    parallelism.

    OOM safety: ``gc.collect()`` runs in a ``finally`` clause, so it executes
    whether processing succeeds or raises.

    Args:
        file_path_str: Absolute path to the recording file (a plain ``str``,
            so it is pickle-safe).
        pipeline_config: Serialised pipeline task list, forwarded unchanged.
        channel_filter: Optional channel whitelist, forwarded unchanged.

    Returns:
        One dict per result row (``DataFrame.to_dict("records")``), or an
        empty list when the engine produced an empty DataFrame.
    """
    import gc as _gc  # local alias so the module-level ``gc`` name is not shadowed
    from pathlib import Path as _Path

    # Importing the package triggers every @AnalysisRegistry.register
    # decorator in this process.
    import Synaptipy.core.analysis  # noqa: F401,F811

    worker_engine = BatchAnalysisEngine(max_workers=1)
    try:
        frame = worker_engine._run_batch_sequential(
            [_Path(file_path_str)],
            pipeline_config,
            None,  # progress_callback: callables cannot be serialised across processes
            channel_filter,
        )
        if frame.empty:
            return []
        return frame.to_dict("records")
    finally:
        _gc.collect()