diff --git a/pyproject.toml b/pyproject.toml index ead0098a..56d9ce7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,9 +43,6 @@ dependencies = [ ] [project.optional-dependencies] -torch = [ - "torch>=2.0" -] gui = [ "pyqtgraph>=0.13.7", "PySide6>=6.6", @@ -66,7 +63,6 @@ docs = [ "sphinxcontrib-spelling>=7.1.0" ] all = [ - "eegprep[torch]", "eegprep[gui]", "eegprep[console]", "eegprep[docs]", diff --git a/src/eegprep/__init__.py b/src/eegprep/__init__.py index 52ad21d1..e129b918 100644 --- a/src/eegprep/__init__.py +++ b/src/eegprep/__init__.py @@ -47,7 +47,6 @@ "ExtensionStatus": ("eegprep.extensions", "ExtensionStatus"), "ExtensionTestHarness": ("eegprep.extension_testing", "ExtensionTestHarness"), "ExtensionValidationResult": ("eegprep.extensions", "ExtensionValidationResult"), - "ICL_feature_extractor": ("eegprep.plugins.ICLabel.ICL_feature_extractor", "ICL_feature_extractor"), "LazyImport": ("eegprep.extensions", "LazyImport"), "assert_extension_entry_point_loads": ("eegprep.extension_testing", "assert_extension_entry_point_loads"), "bids_list_eeg_files": ("eegprep.plugins.EEG_BIDS.bids_list_eeg_files", "bids_list_eeg_files"), @@ -59,16 +58,7 @@ "chancenter": ("eegprep.functions.sigprocfunc.chancenter", "chancenter"), "check_extension_compatibility": ("eegprep.extensions", "check_extension_compatibility"), "checkset": ("eegprep.functions.redefine_functions", "checkset"), - "clean_artifacts": ("eegprep.plugins.clean_rawdata.clean_artifacts", "clean_artifacts"), - "clean_asr": ("eegprep.plugins.clean_rawdata.clean_asr", "clean_asr"), - "clean_channels": ("eegprep.plugins.clean_rawdata.clean_channels", "clean_channels"), - "clean_channels_nolocs": ("eegprep.plugins.clean_rawdata.clean_channels_nolocs", "clean_channels_nolocs"), - "clean_drifts": ("eegprep.plugins.clean_rawdata.clean_drifts", "clean_drifts"), - "clean_flatlines": ("eegprep.plugins.clean_rawdata.clean_flatlines", "clean_flatlines"), - "clean_windows": ("eegprep.plugins.clean_rawdata.clean_windows", "clean_windows"), - "clean_rawdata_vis_artifacts": ("eegprep.plugins.clean_rawdata.vis_artifacts", "vis_artifacts"), "clean_rawdata_vis_artifacts_diagnostics": ( - "eegprep.plugins.clean_rawdata.vis_artifacts", "vis_artifacts_diagnostics", ), "compare": ("eegprep.functions.redefine_functions", "compare"), @@ -79,16 +69,12 @@ "discover_extensions": ("eegprep.extensions", "discover_extensions"), "eeg2mne": ("eegprep.functions.redefine_functions", "eeg2mne"), "eeg_amica": ("eegprep.functions.popfunc.eeg_amica", "eeg_amica"), - "eeg_autocorr": ("eegprep.plugins.ICLabel.eeg_autocorr", "eeg_autocorr"), - "eeg_autocorr_fftw": ("eegprep.plugins.ICLabel.eeg_autocorr_fftw", "eeg_autocorr_fftw"), - "eeg_autocorr_welch": ("eegprep.plugins.ICLabel.eeg_autocorr_welch", "eeg_autocorr_welch"), "eeg_checkset": ("eegprep.functions.adminfunc.eeg_checkset", "eeg_checkset"), "eeg_checkset_strict_mode": ("eegprep.functions.adminfunc.eeg_checkset", "strict_mode"), "eeg_compare": ("eegprep.functions.popfunc.eeg_compare", "eeg_compare"), "eeg_decodechan": ("eegprep.functions.popfunc.eeg_decodechan", "eeg_decodechan"), "eeg_eeg2mne": ("eegprep.functions.miscfunc.eeg_eeg2mne", "eeg_eeg2mne"), "eeg_eegrej": ("eegprep.functions.popfunc.eeg_eegrej", "eeg_eegrej"), - "eeg_icalabelstat": ("eegprep.plugins.ICLabel.eeg_icalabelstat", "eeg_icalabelstat"), "eeg_emptyset": ("eegprep.functions.popfunc.eeg_emptyset", "eeg_emptyset"), "eeg_multieegplot": ("eegprep.functions.popfunc.eeg_multieegplot", "eeg_multieegplot"), "eegplot": ("eegprep.functions.sigprocfunc.eegplot", "eegplot"), @@ -104,7 +90,6 @@ "eeg_pvaf": ("eegprep.functions.sigprocfunc.ica_helpers", "eeg_pvaf"), "eeg_rejsuperpose": ("eegprep.functions.popfunc.eeg_rejsuperpose", "eeg_rejsuperpose"), "eeg_retrieve": ("eegprep.functions.adminfunc.eeg_retrieve", "eeg_retrieve"), - "eeg_rpsd": ("eegprep.plugins.ICLabel.eeg_rpsd", "eeg_rpsd"), "eeg_runica": ("eegprep.functions.popfunc.eeg_runica", "eeg_runica"), "eeg_store": ("eegprep.functions.adminfunc.eeg_store", "eeg_store"), "eegh": ("eegprep.functions.adminfunc.eegh", "eegh"), @@ -131,8 +116,6 @@ "icaact": ("eegprep.functions.sigprocfunc.ica_helpers", "icaact"), "icaproj": ("eegprep.functions.sigprocfunc.ica_helpers", "icaproj"), "icavar": ("eegprep.functions.sigprocfunc.ica_helpers", "icavar"), - "iclabel": ("eegprep.plugins.ICLabel.iclabel", "iclabel"), - "eeg_icflag": ("eegprep.plugins.ICLabel.eeg_icflag", "eeg_icflag"), "inputgui": ("eegprep.functions.guifunc.inputgui", "inputgui"), "interp": ("eegprep.functions.redefine_functions", "interp"), "jointprob": ("eegprep.functions.sigprocfunc.jointprob", "jointprob"), @@ -168,8 +151,6 @@ "rspdfsolv": ("eegprep.functions.timefreqfunc.rspdfsolv", "rspdfsolv"), "rspfunc": ("eegprep.functions.timefreqfunc.rspfunc", "rspfunc"), "tf_cycle_calc": ("eegprep.functions.timefreqfunc.tf_cycle_calc", "tf_cycle_calc"), - "vis_artifacts": ("eegprep.plugins.clean_rawdata.vis_artifacts", "vis_artifacts"), - "vis_artifacts_diagnostics": ("eegprep.plugins.clean_rawdata.vis_artifacts", "vis_artifacts_diagnostics"), "options": ("eegprep.functions.redefine_functions", "options"), "picard": ("eegprep.functions.redefine_functions", "picard"), "plugin_menu": ("eegprep.functions.adminfunc.plugin_menu", "plugin_menu"), @@ -186,7 +167,6 @@ "pop_chanplot": ("eegprep.functions.studyfunc.pop_chanplot", "pop_chanplot"), "pop_chancoresp": ("eegprep.functions.popfunc.pop_chancoresp", "pop_chancoresp"), "pop_chansel": ("eegprep.functions.popfunc.pop_chansel", "pop_chansel"), - "pop_clean_rawdata": ("eegprep.plugins.clean_rawdata.pop_clean_rawdata", "pop_clean_rawdata"), "pop_chanedit": ("eegprep.functions.popfunc.pop_chanedit", "pop_chanedit"), "pop_comments": ("eegprep.functions.popfunc.pop_comments", "pop_comments"), "pop_clust": ("eegprep.functions.studyfunc.pop_clust", "pop_clust"), @@ -234,8 +214,6 @@ "pop_firwsord": ("eegprep.plugins.firfilt.pop_firwsord", "pop_firwsord"), "pop_fusechanrej": ("eegprep.functions.popfunc.pop_fusechanrej", "pop_fusechanrej"), "pop_headplot": ("eegprep.functions.popfunc.pop_headplot", "pop_headplot"), - "pop_icflag": ("eegprep.plugins.ICLabel.pop_icflag", "pop_icflag"), - "pop_iclabel": ("eegprep.plugins.ICLabel.pop_iclabel", "pop_iclabel"), "pop_icathresh": ("eegprep.functions.popfunc.pop_icathresh", "pop_icathresh"), "pop_jointprob": ("eegprep.functions.popfunc.pop_jointprob", "pop_jointprob"), "pop_kaiserbeta": ("eegprep.plugins.firfilt.pop_kaiserbeta", "pop_kaiserbeta"), @@ -266,7 +244,6 @@ "pop_preclust": ("eegprep.functions.studyfunc.pop_preclust", "pop_preclust"), "pop_precomp": ("eegprep.functions.studyfunc.pop_precomp", "pop_precomp"), "pop_prop": ("eegprep.functions.popfunc.pop_prop", "pop_prop"), - "pop_prop_extended": ("eegprep.plugins.ICLabel.pop_prop_extended", "pop_prop_extended"), "pop_rejchan": ("eegprep.functions.popfunc.pop_rejchan", "pop_rejchan"), "pop_rejcont": ("eegprep.functions.popfunc.pop_rejcont", "pop_rejcont"), "pop_rejepoch": ("eegprep.functions.popfunc.pop_rejepoch", "pop_rejepoch"), @@ -300,7 +277,6 @@ "pop_timef": ("eegprep.functions.popfunc.pop_timef", "pop_timef"), "pop_topochansel": ("eegprep.functions.popfunc.pop_topochansel", "pop_topochansel"), "pop_topoplot": ("eegprep.functions.popfunc.pop_topoplot", "pop_topoplot"), - "pop_viewprops": ("eegprep.plugins.ICLabel.pop_viewprops", "pop_viewprops"), "pop_writeeeg": ("eegprep.functions.popfunc.pop_writeeeg", "pop_writeeeg"), "pop_writelocs": ("eegprep.functions.popfunc.pop_writelocs", "pop_writelocs"), "pop_xfirws": ("eegprep.plugins.firfilt.pop_xfirws", "pop_xfirws"), @@ -400,8 +376,35 @@ __all__ = ["__version__", *_LAZY_EXPORTS] +_DECOUPLED_PLUGINS = { + "pop_iclabel": "eegprep-iclabel", + "iclabel": "eegprep-iclabel", + "pop_icflag": "eegprep-iclabel", + "eeg_icalabelstat": "eegprep-iclabel", + "eeg_icflag": "eegprep-iclabel", + "ICL_feature_extractor": "eegprep-iclabel", + "pop_viewprops": "eegprep-iclabel", + "pop_prop_extended": "eegprep-iclabel", + "pop_clean_rawdata": "eegprep-clean-rawdata", + "clean_artifacts": "eegprep-clean-rawdata", + "clean_asr": "eegprep-clean-rawdata", + "clean_channels": "eegprep-clean-rawdata", + "clean_channels_nolocs": "eegprep-clean-rawdata", + "clean_drifts": "eegprep-clean-rawdata", + "clean_flatlines": "eegprep-clean-rawdata", + "clean_windows": "eegprep-clean-rawdata", + "vis_artifacts": "eegprep-clean-rawdata", +} + def __getattr__(name: str) -> Any: """Load public EEGPrep exports on first access.""" + if name in _DECOUPLED_PLUGINS: + pkg = _DECOUPLED_PLUGINS[name] + raise RuntimeError( + f"The feature {name!r} requires an uninstalled plugin package. " + f"Please install it using: pip install {pkg}" + ) + try: module_name, attr_name = _LAZY_EXPORTS[name] except KeyError as exc: diff --git a/src/eegprep/cli/commands/transforms.py b/src/eegprep/cli/commands/transforms.py index 368a6ae4..f1fdcded 100644 --- a/src/eegprep/cli/commands/transforms.py +++ b/src/eegprep/cli/commands/transforms.py @@ -31,7 +31,6 @@ from eegprep.functions.popfunc.pop_resample import pop_resample from eegprep.functions.popfunc.pop_runica import pop_runica from eegprep.functions.popfunc.pop_saveset import pop_saveset -from eegprep.plugins.clean_rawdata.pop_clean_rawdata import pop_clean_rawdata from eegprep.plugins.firfilt.pop_eegfiltnew import pop_eegfiltnew diff --git a/src/eegprep/extensions.py b/src/eegprep/extensions.py index cc91311a..6bfeb204 100644 --- a/src/eegprep/extensions.py +++ b/src/eegprep/extensions.py @@ -19,8 +19,6 @@ eeg_bids_import_items, eeg_bids_tools_menu, ) -from eegprep.plugins.ICLabel.menu import iclabel_menu, viewprops_plot_menus -from eegprep.plugins.clean_rawdata.menu import clean_rawdata_menu from eegprep.plugins.dipfit.menu import dipfit_menu from eegprep.plugins.firfilt.menu import firfilt_filter_items @@ -328,62 +326,6 @@ def get(self, name: str) -> ExtensionRecord | None: def _bundled_records(self) -> list[ExtensionRecord]: specs = ( - ExtensionSpec( - name="clean_rawdata", - display_name="clean_rawdata", - version="bundled", - package_name="eegprep.plugins.clean_rawdata", - source_type=ExtensionSourceType.BUNDLED, - description="Artifact Subspace Reconstruction and related channel/window cleaning workflows.", - capabilities=("artifact", "preprocessing"), - menus=( - _extension_menu_from_spec( - clean_rawdata_menu(), - path=("tools",), - insert_after="pop_eegplot:data", - ), - ), - pop_functions=( - ExtensionPopFunction( - name="pop_clean_rawdata", - target=LazyImport( - "eegprep.plugins.clean_rawdata.pop_clean_rawdata", - "pop_clean_rawdata", - ), - ), - ), - ), - ExtensionSpec( - name="ICLabel", - display_name="ICLabel", - version="bundled", - package_name="eegprep.plugins.ICLabel", - source_type=ExtensionSourceType.BUNDLED, - description="Independent-component classification, flagging, and extended component properties.", - capabilities=("ica", "classification"), - menus=( - _extension_menu_from_spec( - iclabel_menu(), - path=("tools",), - insert_after="pop_selectcomps", - ), - *(_extension_menu_from_spec(item, path=("plot",)) for item in viewprops_plot_menus()), - ), - pop_functions=( - ExtensionPopFunction( - name="pop_iclabel", - target=LazyImport("eegprep.plugins.ICLabel.pop_iclabel", "pop_iclabel"), - ), - ExtensionPopFunction( - name="pop_icflag", - target=LazyImport("eegprep.plugins.ICLabel.pop_icflag", "pop_icflag"), - ), - ExtensionPopFunction( - name="pop_viewprops", - target=LazyImport("eegprep.plugins.ICLabel.pop_viewprops", "pop_viewprops"), - ), - ), - ), ExtensionSpec( name="firfilt", display_name="firfilt", diff --git a/src/eegprep/functions/guifunc/menu_actions.py b/src/eegprep/functions/guifunc/menu_actions.py index 71f3cc84..1b6171be 100644 --- a/src/eegprep/functions/guifunc/menu_actions.py +++ b/src/eegprep/functions/guifunc/menu_actions.py @@ -40,7 +40,6 @@ "pop_chanplot", "pop_clust", "pop_clustedit", - "pop_clean_rawdata", "pop_chanedit", "pop_comperp", "pop_delset", @@ -78,8 +77,6 @@ "pop_firpm", "pop_firws", "pop_headplot", - "pop_icflag", - "pop_iclabel", "pop_jointprob", "pop_leadfield", "pop_importbids", @@ -125,7 +122,6 @@ "pop_timtopo", "pop_taskinfo", "pop_topoplot", - "pop_viewprops", "pop_participantinfo", "pop_mergeset", "pop_multifit", @@ -150,7 +146,6 @@ EEGPREP_SOURCE_URL = f"{EEGPREP_REPO_URL}/blob/develop" _MULTIPLE_DATASET_ACTIONS = { - "pop_clean_rawdata", "pop_chanedit", "pop_eegfilt", "pop_eegfiltnew", @@ -158,8 +153,6 @@ "pop_firma", "pop_firpm", "pop_firws", - "pop_icflag", - "pop_iclabel", "pop_reref", "pop_rmdat", "pop_resample", @@ -198,7 +191,6 @@ } _NEWSET_COMMIT_ACTIONS = { - "pop_clean_rawdata", "pop_eegfilt", "pop_eegfiltnew", "pop_epoch", @@ -225,7 +217,6 @@ "pop_editeventfield", "pop_editeventvals", "pop_chanedit", - "pop_clean_rawdata", "pop_eegfilt", "pop_eegfiltnew", "pop_epoch", @@ -240,8 +231,6 @@ "pop_runica", "pop_select", "pop_selectevent", - "pop_iclabel", - "pop_icflag", "pop_subcomp", } @@ -400,7 +389,6 @@ def dispatch(self, action: str, parent: Any | None = None) -> None: "pop_rejspec", "pop_rejtrend", "pop_selectcomps", - "pop_viewprops", }: self._run_pop_function(base, parent, variant=variant) return @@ -1065,10 +1053,6 @@ def _run_pop_function(self, name: str, parent: Any | None, *, variant: str = "") from eegprep.functions.popfunc.pop_chanedit import pop_chanedit out = pop_chanedit(selection, return_com=True) - elif name == "pop_clean_rawdata": - from eegprep.plugins.clean_rawdata.pop_clean_rawdata import pop_clean_rawdata - - out = pop_clean_rawdata(selection, return_com=True) elif name == "pop_eegfilt": from eegprep.functions.popfunc.pop_eegfilt import pop_eegfilt @@ -1101,14 +1085,6 @@ def _run_pop_function(self, name: str, parent: Any | None, *, variant: str = "") from eegprep.functions.popfunc.pop_interp import pop_interp out = pop_interp(selection, alleeg=self.session.ALLEEG, return_com=True) - elif name == "pop_iclabel": - from eegprep.plugins.ICLabel.pop_iclabel import pop_iclabel - - out = pop_iclabel(selection, return_com=True) - elif name == "pop_icflag": - from eegprep.plugins.ICLabel.pop_icflag import pop_icflag - - out = pop_icflag(selection, return_com=True) elif name == "pop_resample": from eegprep.functions.popfunc.pop_resample import pop_resample @@ -1238,22 +1214,7 @@ def accept_browser_result(eeg_out: Any, command: str) -> None: from eegprep.functions.popfunc.pop_selectcomps import pop_selectcomps out = pop_selectcomps(selection, return_com=True) - elif name == "pop_viewprops": - from eegprep.plugins.ICLabel.pop_viewprops import pop_viewprops - - target_index = list(self.session.CURRENTSET) - - def commit_component_rejection(eeg_out: Any, _states: dict[int, bool]) -> None: - with self.session.gui_action("pop_viewprops"): - self._store_current_from_gui(eeg_out, command="", index=target_index) - self._refresh() - out = pop_viewprops( - selection, - typecomp=0 if variant == "components" else 1, - reject_callback=commit_component_rejection, - return_com=True, - ) else: self.show_coming_soon(name, parent) return @@ -1261,10 +1222,6 @@ def commit_component_rejection(eeg_out: Any, _states: dict[int, bool]) -> None: eeg_out, command = out[0], out[1] if len(out) > 1 else "" else: eeg_out, command = out, "" - if name == "pop_viewprops": - self._add_history_from_gui(command) - self._refresh() - return if command: if name in _NEWSET_COMMIT_ACTIONS: self._commit_processed_dataset_from_gui(eeg_out, command=command, parent=parent) diff --git a/src/eegprep/functions/sigprocfunc/runica.py b/src/eegprep/functions/sigprocfunc/runica.py index 5c23534b..5e081aef 100644 --- a/src/eegprep/functions/sigprocfunc/runica.py +++ b/src/eegprep/functions/sigprocfunc/runica.py @@ -24,11 +24,18 @@ import numpy as np from scipy.linalg import sqrtm, pinv, eig -from ...plugins.clean_rawdata.private.ransac import rand_permutation -from ..miscfunc.misc import finite_pinv +from ..miscfunc.misc import finite_pinv, round_mat logger = logging.getLogger(__name__) +def rand_permutation(n: int, stream: np.random.RandomState) -> np.ndarray: + """Random permutation with MATLAB parity using Fisher-Yates shuffle.""" + result = np.arange(n) + for k in range(n - 1, 0, -1): + j = int(round_mat(k * stream.rand())) + result[k], result[j] = result[j], result[k] + return result + def _matmul(left, right): # MATLAB mtimes does not surface BLAS floating-point status warnings for diff --git a/src/eegprep/plugins/ICLabel/ICL_feature_extractor.py b/src/eegprep/plugins/ICLabel/ICL_feature_extractor.py deleted file mode 100644 index 88bdd5aa..00000000 --- a/src/eegprep/plugins/ICLabel/ICL_feature_extractor.py +++ /dev/null @@ -1,104 +0,0 @@ -"""ICLabel feature extraction functions.""" - -from copy import deepcopy -import numpy as np - - -def ICL_feature_extractor(EEG, flag_autocorr=False): - """Extract features for ICLabel classification. - - Parameters - ---------- - EEG : dict - EEG data structure with ICA - flag_autocorr : bool, optional - Whether to include autocorrelation features (default False) - - Returns - ------- - features : list - List of feature arrays - """ - from eegprep import topoplot - from eegprep import eeg_rpsd - from eegprep import eeg_autocorr_welch - from eegprep import eeg_autocorr - from eegprep import eeg_autocorr_fftw - from eegprep import pop_reref - - EEG = deepcopy(EEG) - - # Check for ICA key and if it is not empty before dereferencing it - if 'icawinv' not in EEG.keys() or EEG['icawinv'].size == 0: - raise ValueError('You must have an ICA decomposition to use ICLabel') - - ncomp = EEG['icawinv'].shape[1] - - # Assuming chanlocs are correct - if EEG.get('ref') != 'average' and EEG.get('ref') != 'averef': - EEG = pop_reref(EEG, []) - - # Calculate ICA activations if missing and cast to double - if EEG['icaact'] is None: - raise ValueError('You must have ICA activations to use ICLabel') - # EEG['icaact'] = eeg_getica(EEG) - - EEG['icaact'] = EEG['icaact'].astype(float) - - # Check ICA is real - assert np.isreal(EEG['icaact']).all(), 'Your ICA decomposition must be real to use ICLabel' - - # Calculate topo - topo = np.zeros((32, 32, 1, ncomp)) - for it in range(ncomp): - tmp_chanlocs = [EEG['chanlocs'][i] for i in EEG['icachansind']] - _, temp_topo, _, _, _ = topoplot(EEG['icawinv'][:, it], tmp_chanlocs, noplot='on', gridscale=32) - temp_topo[np.isnan(temp_topo)] = 0 - topo[:, :, 0, it] = temp_topo / np.max(np.abs(temp_topo)) - - # Cast - topo = topo.astype(np.float32) - - # Calculate PSD - psd = eeg_rpsd(EEG, 100) - - # Extrapolate or prune as needed - nfreq = psd.shape[1] - if nfreq < 100: - psd = np.hstack((psd, np.tile(psd[:, -1][:, np.newaxis], (1, 100 - nfreq)))) - - # Undo notch filter - for linenoise_ind in [50, 60]: - linenoise_around = [linenoise_ind - 1, linenoise_ind + 1] - difference = psd[:, linenoise_around] - psd[:, linenoise_ind][:, np.newaxis] - notch_ind = np.all(difference > 5, axis=1) - if np.any(notch_ind): - psd[notch_ind, linenoise_ind] = np.mean(psd[notch_ind][:, linenoise_around], axis=1) - - # Normalize - psd = psd / np.max(np.abs(psd), axis=1)[:, np.newaxis] - - psd = np.transpose(np.expand_dims(np.expand_dims(psd, axis=-1), axis=-1), (2, 1, 3, 0)).astype(np.float32) - - # Calculate autocorrelation - if flag_autocorr: - if EEG['trials'] == 1: - if EEG['pnts'] / EEG['srate'] > 5: - autocorr = eeg_autocorr_welch(EEG) - else: - autocorr = eeg_autocorr(EEG) - else: - autocorr = eeg_autocorr_fftw(EEG) - - # Reshape and cast - autocorr = np.transpose(np.expand_dims(np.expand_dims(autocorr, axis=-1), axis=-1), (2, 1, 3, 0)).astype( - np.float32 - ) - - # Format outputs - if flag_autocorr: - features = [0.99 * topo, 0.99 * psd, 0.99 * autocorr] - else: - features = [0.99 * topo, 0.99 * psd] - - return features diff --git a/src/eegprep/plugins/ICLabel/__init__.py b/src/eegprep/plugins/ICLabel/__init__.py deleted file mode 100644 index f90a2a48..00000000 --- a/src/eegprep/plugins/ICLabel/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""ICLabel plugin ports.""" diff --git a/src/eegprep/plugins/ICLabel/_prop_browser.py b/src/eegprep/plugins/ICLabel/_prop_browser.py deleted file mode 100644 index 36143b87..00000000 --- a/src/eegprep/plugins/ICLabel/_prop_browser.py +++ /dev/null @@ -1,778 +0,0 @@ -"""Matplotlib rendering for ICLabel extended property dashboards.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from typing import Any - -from matplotlib.widgets import Button -import matplotlib.pyplot as plt -import numpy as np - -from eegprep.functions.guifunc.pophelp import pophelp -from eegprep.functions.popfunc._property_browser import property_activity_browser -from eegprep.functions.popfunc._rejection import ( - component_rejection_flags, - set_component_rejection_flag, -) -from eegprep.functions.sigprocfunc.topoplot import topoplot -from eegprep.plugins.ICLabel._prop_numerics import ( - DipfitData, - ExtendedPropertyData, - build_extended_property_data, - component_count, - component_rejection_status, -) -from eegprep.plugins.dipfit._mri import dipfit_mri_slices, load_standard_mri_volume - - -_DASHBOARD_SIZE = (12.0, 7.0) -_SCROLL_SECONDS = 5.0 -_EVENT_COLORS = ( - "#1f77b4", - "#2ca02c", - "#9467bd", - "#17becf", - "#ff7f0e", - "#8c564b", - "#e377c2", - "#7f7f7f", -) -_DIPFIT_COLORS = ("#00cc00", "#d336d3", "#e0c21a") -_REJECT_COLOR = "#ff9999" -_ACCEPT_COLOR = "#bfffbf" -_CONTROL_COLOR = "#e6e6e6" -_DISABLED_CONTROL_COLOR = "#d0d0d0" - - -@dataclass(frozen=True) -class ActivityTraceData: - """Inline scroll-plot trace using EEGLAB-compatible event coordinates.""" - - x_values: np.ndarray - y_values: np.ndarray - times_ms: np.ndarray - pnts: int - epoched: bool - - -def build_navigable_dashboard( - EEG: dict[str, Any], - typecomp: int, - indices: list[int], - winhandle: Any, - spec_opt: Any, - erp_opt: Any, - scroll_event: int | bool, - classifier_name: str, - *, - fig: Any, - show_activity: bool, - reject_callback: Any | None, -) -> Any: - """Build the navigable Matplotlib property dashboard.""" - figure = fig if fig is not None else plt.figure(figsize=_DASHBOARD_SIZE) - state = { - "EEG": EEG, - "typecomp": int(typecomp), - "indices": tuple(indices), - "position": 0, - "spec_opt": spec_opt, - "erp_opt": erp_opt, - "scroll_event": int(bool(scroll_event)), - "classifier_name": classifier_name, - "show_activity": bool(show_activity), - "winhandle": winhandle, - "reject_callback": reject_callback, - "rejection_pending": _initial_rejection_state(EEG, int(typecomp), indices), - } - figure.eegprep_dashboard_state = state - _render_dashboard(figure) - return figure - - -def _render_dashboard(figure: Any) -> None: - state = figure.eegprep_dashboard_state - indices = state["indices"] - position = int(state["position"]) - index = int(indices[position]) - dashboard = build_extended_property_data( - state["EEG"], - state["typecomp"], - index, - spec_opt=state["spec_opt"], - erp_opt=state["erp_opt"], - classifier_name=state["classifier_name"], - ) - figure.clf() - figure.set_size_inches(*_DASHBOARD_SIZE, forward=True) - figure.patch.set_facecolor((0.93, 0.96, 1.0)) - _set_window_title(figure, dashboard.figure_title) - has_rejection_controls = _has_rejection_controls(state) - bottom = ( - 0.17 - if has_rejection_controls and len(indices) > 1 - else 0.125 - if has_rejection_controls - else 0.12 - if len(indices) > 1 - else 0.075 - ) - if dashboard.dipfit is None: - grid = figure.add_gridspec( - 2, - 4, - left=0.055, - right=0.97, - top=0.88, - bottom=bottom, - wspace=0.65, - hspace=0.55, - width_ratios=(1.2, 0.95, 1.25, 1.25), - height_ratios=(0.9, 1.2), - ) - topo_ax = figure.add_subplot(grid[0, 0]) - if dashboard.classifier is None: - class_ax = None - activity_ax = figure.add_subplot(grid[0, 1:]) - else: - class_ax = figure.add_subplot(grid[0, 1]) - activity_ax = figure.add_subplot(grid[0, 2:]) - image_ax = figure.add_subplot(grid[1, :2]) - dipfit_axes = [] - spectrum_ax = figure.add_subplot(grid[1, 2:]) - else: - grid = figure.add_gridspec( - 2, - 5, - left=0.055, - right=0.97, - top=0.88, - bottom=bottom, - wspace=0.58, - hspace=0.55, - width_ratios=(1.15, 0.9, 0.85, 1.25, 1.25), - height_ratios=(0.9, 1.2), - ) - topo_ax = figure.add_subplot(grid[0, 0]) - if dashboard.classifier is None: - class_ax = None - activity_ax = figure.add_subplot(grid[0, 1:]) - else: - class_ax = figure.add_subplot(grid[0, 1]) - activity_ax = figure.add_subplot(grid[0, 2:]) - image_ax = figure.add_subplot(grid[1, :2]) - dipfit_grid = grid[1, 2].subgridspec(3, 1, hspace=0.04) - dipfit_axes = [figure.add_subplot(dipfit_grid[row, 0]) for row in range(3)] - spectrum_ax = figure.add_subplot(grid[1, 3:]) - - _plot_topography(topo_ax, dashboard) - if class_ax is not None: - _plot_classifier(class_ax, dashboard) - events = state["EEG"].get("event", []) if bool(state["scroll_event"]) else [] - _plot_activity(activity_ax, dashboard, events) - _plot_activity_image(image_ax, dashboard) - if dipfit_axes: - _plot_dipfit(dipfit_axes, dashboard.dipfit) - _plot_spectrum(spectrum_ax, dashboard) - figure.suptitle(dashboard.figure_title, fontsize=14, fontweight="bold") - figure.eegprep_dashboard_data = dashboard - figure.eegprep_activity_view = property_activity_browser( - state["EEG"], - dashboard.typecomp, - dashboard.index, - scroll_event=state["scroll_event"], - show=state["show_activity"], - ) - if len(indices) > 1: - _add_navigation_controls( - figure, - bottom=0.075 if has_rejection_controls else 0.025, - count_y=0.1 if has_rejection_controls else 0.092, - ) - else: - figure.eegprep_dashboard_navigation = {} - figure.eegprep_dashboard_navigation_buttons = () - if has_rejection_controls: - _add_rejection_controls(figure, dashboard) - else: - figure.eegprep_dashboard_rejection = {} - figure.eegprep_dashboard_rejection_buttons = {} - figure.eegprep_dashboard_rejection_button_list = () - buttons = [] - buttons.extend(getattr(figure, "eegprep_dashboard_navigation_buttons", ())) - buttons.extend(getattr(figure, "eegprep_dashboard_rejection_button_list", ())) - figure.eegprep_dashboard_buttons = tuple(buttons) - figure.canvas.draw_idle() - - -def _plot_topography(axis: Any, dashboard: ExtendedPropertyData) -> None: - if dashboard.typecomp: - topoplot( - dashboard.topography_values, - dashboard.topography_chanlocs, - axes=axis, - style="blank", - electrodes="off", - ) - else: - topoplot( - dashboard.topography_values, - dashboard.topography_chanlocs, - axes=axis, - electrodes="on", - colorbar=False, - ) - axis.set_title(dashboard.topography_title, fontsize=12, fontweight="normal") - if dashboard.pvaf is not None: - axis.text( - 0.5, - -0.13, - f"{{% scalp data var. accounted for}}: {dashboard.pvaf:.1f}%", - transform=axis.transAxes, - ha="center", - va="top", - fontsize=9, - ) - - -def _plot_classifier(axis: Any, dashboard: ExtendedPropertyData) -> None: - assert dashboard.classifier is not None - assert dashboard.class_probabilities is not None - labels = list(reversed(dashboard.classifier.classes)) - probabilities = np.asarray(dashboard.class_probabilities, dtype=float)[::-1] - y_values = np.arange(len(labels)) - axis.barh(y_values, probabilities, color="#4c78a8") - axis.set_yticks(y_values, labels) - axis.set_xlim(0.0, 1.0) - axis.set_xticks([0.0, 0.5, 1.0]) - axis.grid(axis="x", alpha=0.3) - axis.set_xlabel("Probability") - axis.set_title(dashboard.classifier.name, fontsize=12, fontweight="normal") - for y_value, probability in zip(y_values, probabilities): - axis.text(0.5, y_value, f"{probability * 100:.1f}%", ha="center", va="center", fontsize=8) - - -def _plot_activity(axis: Any, dashboard: ExtendedPropertyData, events: Any) -> None: - trace = _activity_trace(dashboard) - axis.plot(trace.x_values, trace.y_values, color="black", linewidth=0.85) - axis.axhline(0.0, color="0.75", linewidth=0.6) - _plot_epoch_markers(axis, trace) - _plot_event_markers(axis, trace, events) - axis.set_title(dashboard.activity_title, fontsize=12, fontweight="normal") - axis.set_xlabel("Time (ms)") - axis.set_ylabel("uV") - axis.grid(True, alpha=0.2) - _format_scrollplot_axis(axis, trace) - - -def _plot_activity_image(axis: Any, dashboard: ExtendedPropertyData) -> None: - image = np.asarray(dashboard.image_data, dtype=float) - handle = axis.imshow( - image, - aspect="auto", - origin="lower", - extent=dashboard.image_extent, - cmap="RdBu_r", - ) - axis.set_title(dashboard.image_title, fontsize=12, fontweight="normal") - axis.set_xlabel("Time (ms)" if dashboard.activity.shape[2] > 1 else "Data") - axis.set_ylabel("Epoch" if dashboard.activity.shape[2] > 1 else "Data") - plt.colorbar(handle, ax=axis, fraction=0.046, pad=0.035) - - -def _plot_spectrum(axis: Any, dashboard: ExtendedPropertyData) -> None: - axis.plot(dashboard.spectrum_freqs, dashboard.spectrum_power, color="black", linewidth=1.0) - axis.set_title(dashboard.spectrum_title, fontsize=12, fontweight="normal") - axis.set_xlabel("Frequency (Hz)") - axis.set_ylabel("Power 10*log10(uV^2/Hz)") - axis.grid(True, alpha=0.25) - finite = np.isfinite(dashboard.spectrum_power) - if np.any(finite): - finite_values = dashboard.spectrum_power[finite] - low = float(np.min(finite_values)) - high = float(np.max(finite_values)) - if high == low: - padding = 1.0 if low == 0.0 else abs(low) * 0.05 - low -= padding - high += padding - axis.set_ylim(low, high) - - -def _plot_dipfit(axes: list[Any], dipfit: DipfitData | None) -> None: - assert dipfit is not None - volume = load_standard_mri_volume() - for axis, mri_slice in zip(axes, dipfit_mri_slices(volume, dipfit.positions)): - axis.set_facecolor("black") - axis.imshow(mri_slice.image, cmap="gray", origin="lower", extent=mri_slice.extent, interpolation="nearest") - axis.set_aspect("equal", adjustable="box") - axis.set_xticks([]) - axis.set_yticks([]) - for spine in axis.spines.values(): - spine.set_visible(False) - _plot_dipfit_points(axis, dipfit, mri_slice.x_axis, mri_slice.y_axis) - axes[0].set_title("Dipole Position", fontsize=12, fontweight="normal", pad=7) - _plot_dipfit_values(axes[-1], dipfit) - - -def _plot_dipfit_points(axis: Any, dipfit: DipfitData, x_index: int, y_index: int) -> None: - for row, position in enumerate(dipfit.positions): - color = _DIPFIT_COLORS[row % len(_DIPFIT_COLORS)] - x_value = float(position[x_index]) - y_value = float(position[y_index]) - axis.plot( - x_value, - y_value, - marker="o", - markersize=5.5, - color=color, - markeredgecolor="white", - markeredgewidth=0.45, - ) - if dipfit.moments is not None and row < dipfit.moments.shape[0]: - _plot_dipfit_moment(axis, x_value, y_value, dipfit.moments[row], x_index, y_index, color) - - -def _plot_dipfit_moment( - axis: Any, - x_value: float, - y_value: float, - moment: np.ndarray, - x_index: int, - y_index: int, - color: str, -) -> None: - dx = float(moment[x_index]) - dy = float(moment[y_index]) - norm = float(np.hypot(dx, dy)) - if not np.isfinite(norm) or norm <= 0.0: - return - axis.arrow( - x_value, - y_value, - dx / norm * 18.0, - dy / norm * 18.0, - color=color, - width=0.8, - head_width=5.0, - length_includes_head=True, - alpha=0.9, - ) - - -def _plot_dipfit_values(axis: Any, dipfit: DipfitData) -> None: - lines = [] - if dipfit.rv_percent is not None: - lines.append(f"RV: {dipfit.rv_percent:.1f}%") - if dipfit.dmr is not None: - lines.append(f"DMR: {dipfit.dmr:.1f}") - if lines: - axis.text( - 0.5, - -0.02, - "\n".join(lines), - transform=axis.transAxes, - color="black", - fontsize=8, - ha="center", - va="top", - ) - - -def _add_navigation_controls(figure: Any, *, bottom: float, count_y: float) -> None: - previous_axis = figure.add_axes((0.37, bottom, 0.105, 0.05)) - next_axis = figure.add_axes((0.525, bottom, 0.105, 0.05)) - previous_button = Button(previous_axis, "Previous") - next_button = Button(next_axis, "Next") - - def previous(_event: Any = None) -> None: - _navigate_dashboard(figure, -1) - - def next_(_event: Any = None) -> None: - _navigate_dashboard(figure, 1) - - previous_button.on_clicked(previous) - next_button.on_clicked(next_) - figure.eegprep_dashboard_navigation_buttons = (previous_button, next_button) - figure.eegprep_dashboard_navigation = {"previous": previous, "next": next_} - state = figure.eegprep_dashboard_state - figure.text( - 0.5, - count_y, - f"{int(state['position']) + 1} / {len(state['indices'])}", - ha="center", - va="center", - fontsize=9, - ) - - -def _add_rejection_controls(figure: Any, dashboard: ExtendedPropertyData) -> None: - state = figure.eegprep_dashboard_state - index = int(dashboard.index) - rejected = _pending_rejection_status(state, index) - cancel_button = Button(figure.add_axes((0.2, 0.015, 0.1, 0.045)), "Cancel", color=_CONTROL_COLOR) - values_button = Button(figure.add_axes((0.325, 0.015, 0.1, 0.045)), "Values", color=_CONTROL_COLOR) - status_button = Button( - figure.add_axes((0.45, 0.015, 0.1, 0.045)), - _rejection_label(rejected), - color=_rejection_color(rejected), - hovercolor=_rejection_color(rejected), - ) - help_button = Button(figure.add_axes((0.575, 0.015, 0.1, 0.045)), "HELP", color=_CONTROL_COLOR) - ok_button = Button(figure.add_axes((0.7, 0.015, 0.1, 0.045)), "OK", color=_CONTROL_COLOR) - - if not _component_values_available(state["EEG"]): - values_button.set_active(False) - values_button.ax.set_facecolor(_DISABLED_CONTROL_COLOR) - - def cancel(_event: Any = None) -> None: - plt.close(figure) - - def values(_event: Any = None) -> None: - if _component_values_available(state["EEG"]): - _show_component_values(state["EEG"], index, _pending_rejection_status(state, index)) - - def toggle(_event: Any = None) -> None: - state["rejection_pending"][index] = not _pending_rejection_status(state, index) - _style_rejection_status_button(status_button, state["rejection_pending"][index]) - figure.canvas.draw_idle() - - def help_(_event: Any = None) -> None: - pophelp("pop_prop_extended") - - def ok(_event: Any = None) -> None: - _commit_rejection_state(figure) - plt.close(figure) - - cancel_button.on_clicked(cancel) - values_button.on_clicked(values) - status_button.on_clicked(toggle) - help_button.on_clicked(help_) - ok_button.on_clicked(ok) - figure.eegprep_dashboard_rejection_button_list = ( - cancel_button, - values_button, - status_button, - help_button, - ok_button, - ) - figure.eegprep_dashboard_rejection_buttons = { - "cancel": cancel_button, - "values": values_button, - "status": status_button, - "help": help_button, - "ok": ok_button, - } - figure.eegprep_dashboard_rejection = { - "cancel": cancel, - "values": values, - "toggle": toggle, - "help": help_, - "ok": ok, - "pending": state["rejection_pending"], - } - - -def _navigate_dashboard(figure: Any, step: int) -> None: - state = figure.eegprep_dashboard_state - state["position"] = (int(state["position"]) + int(step)) % len(state["indices"]) - _render_dashboard(figure) - - -def _initial_rejection_state(EEG: dict[str, Any], typecomp: int, indices: list[int]) -> dict[int, bool]: - if int(typecomp): - return {} - total = component_count(EEG) - flags = component_rejection_flags(EEG, total, create=False) - return {int(index): bool(flags[int(index) - 1]) for index in indices} - - -def _has_rejection_controls(state: dict[str, Any]) -> bool: - return int(state["typecomp"]) == 0 and bool(state["indices"]) - - -def _pending_rejection_status(state: dict[str, Any], component_index: int) -> bool: - pending = state["rejection_pending"] - index = int(component_index) - if index not in pending: - pending[index] = component_rejection_status(state["EEG"], index) - return bool(pending[index]) - - -def _rejection_label(rejected: bool) -> str: - return "REJECT" if rejected else "ACCEPT" - - -def _rejection_color(rejected: bool) -> str: - return _REJECT_COLOR if rejected else _ACCEPT_COLOR - - -def _style_rejection_status_button(button: Button, rejected: bool) -> None: - button.label.set_text(_rejection_label(rejected)) - button.ax.set_facecolor(_rejection_color(rejected)) - button.color = _rejection_color(rejected) - button.hovercolor = _rejection_color(rejected) - - -def _commit_rejection_state(figure: Any) -> None: - state = figure.eegprep_dashboard_state - if not _has_rejection_controls(state): - return - EEG = state["EEG"] - total = component_count(EEG) - committed: dict[int, bool] = {} - for index in sorted(state["rejection_pending"]): - rejected = bool(state["rejection_pending"][index]) - set_component_rejection_flag(EEG, index, rejected, total) - _style_rejection_winhandle(state["winhandle"], index, rejected) - committed[int(index)] = rejected - callback = state.get("reject_callback") - if callback is not None: - callback(EEG, dict(committed)) - - -def _style_rejection_winhandle(winhandle: Any, component_index: int, rejected: bool) -> None: - if _is_empty_winhandle(winhandle): - return - handle = winhandle - if isinstance(winhandle, Mapping): - handle = winhandle.get(int(component_index)) - if isinstance(handle, Button): - handle.ax.set_facecolor(_rejection_color(rejected)) - - -def _is_empty_winhandle(winhandle: Any) -> bool: - if winhandle is None: - return True - if isinstance(winhandle, (int, float, np.integer, np.floating)): - return bool(winhandle == 0) or bool(np.isnan(float(winhandle))) - return False - - -def _component_values_available(EEG: dict[str, Any]) -> bool: - stats = EEG.get("stats") - if not isinstance(stats, dict): - return False - return np.asarray(stats.get("compenta", [])).size > 0 - - -def _show_component_values(EEG: dict[str, Any], component_index: int, rejected: bool) -> None: - values = _component_value_lines(EEG, component_index, rejected) - figure = plt.figure(figsize=(3.4, 3.4)) - manager = getattr(figure.canvas, "manager", None) - if manager is not None: - manager.set_window_title("Statistics of the component") - axis = figure.add_subplot(1, 1, 1) - axis.axis("off") - axis.text(0.04, 0.96, "\n".join(values), va="top", ha="left", fontsize=9, family="monospace") - close_button = Button(figure.add_axes((0.375, 0.03, 0.25, 0.08)), "Close", color=_CONTROL_COLOR) - close_button.on_clicked(lambda _event=None: plt.close(figure)) - setattr(figure, "eegprep_component_values_button", close_button) - figure.tight_layout(rect=(0, 0.12, 1, 1)) - figure.canvas.draw_idle() - - -def _component_value_lines(EEG: dict[str, Any], component_index: int, rejected: bool) -> list[str]: - raw_stats = EEG.get("stats") - stats = raw_stats if isinstance(raw_stats, dict) else {} - raw_reject = EEG.get("reject") - reject = raw_reject if isinstance(raw_reject, dict) else {} - index = int(component_index) - return [ - "(", - f"Entropy of component activity {_indexed_stat(stats.get('compenta'), index):>8}", - f"> Rejection threshold {_scalar_stat(reject.get('threshentropy')):>8}", - "", - " AND ----", - "", - f"Kurtosis of component activity {_indexed_stat(stats.get('compkurta'), index):>8}", - f"> Rejection threshold {_scalar_stat(reject.get('threshkurtact')):>8}", - "", - ") OR ----", - "", - f"Kurtosis distribution {_indexed_stat(stats.get('compkurtdist'), index):>8}", - f"> Rejection threshold {_scalar_stat(reject.get('threshkurtdist')):>8}", - "", - f"Current thresholds suggest to {_rejection_label(rejected)} the component", - "", - "After manually accepting/rejecting the component, recalibrate", - "thresholds before applying automatic rejection to other datasets.", - ] - - -def _indexed_stat(values: Any, component_index: int) -> str: - vector = np.asarray(values, dtype=float).ravel() - index = int(component_index) - 1 - if index < 0 or index >= vector.size or not np.isfinite(vector[index]): - return "----" - return f"{float(vector[index]):2.2f}" - - -def _scalar_stat(value: Any) -> str: - vector = np.asarray(value, dtype=float).ravel() - if vector.size == 0 or not np.isfinite(vector[0]): - return "----" - return f"{float(vector[0]):2.2f}" - - -def _activity_trace(dashboard: ExtendedPropertyData) -> ActivityTraceData: - trace = np.asarray(dashboard.activity[0], dtype=float) - if trace.ndim == 1: - trace = trace[:, np.newaxis] - pnts = int(dashboard.times_ms.size) - srate = _srate_from_times(dashboard.times_ms) - window_samples = max(1, int(round(_SCROLL_SECONDS * srate))) - if trace.shape[1] == 1: - sample_count = min(pnts, window_samples) - return ActivityTraceData( - x_values=np.arange(1, sample_count + 1, dtype=float), - y_values=trace[:sample_count, 0], - times_ms=dashboard.times_ms, - pnts=pnts, - epoched=False, - ) - flat = trace.T.reshape(-1) - sample_count = min(flat.size, window_samples) - return ActivityTraceData( - x_values=np.arange(1, sample_count + 1, dtype=float), - y_values=flat[:sample_count], - times_ms=dashboard.times_ms, - pnts=pnts, - epoched=True, - ) - - -def _plot_event_markers( - axis: Any, - trace: ActivityTraceData, - events: Any, -) -> None: - event_items = _event_items(events) - if not event_items: - return - first = float(np.nanmin(trace.x_values)) - last = float(np.nanmax(trace.x_values)) - colors = _event_color_map(event_items) - for event in event_items: - if "latency" not in event: - continue - try: - latency = float(_event_scalar(event["latency"])) - except (TypeError, ValueError): - continue - x_value = latency - if first <= x_value <= last: - label = str(_event_scalar(event.get("type", ""))) - axis.axvline(x_value, color=colors[label], linestyle="--", linewidth=0.75) - axis.text( - x_value, - 0.98, - label, - color=colors[label], - transform=axis.get_xaxis_transform(), - rotation=45, - fontsize=8, - ) - - -def _plot_epoch_markers(axis: Any, trace: ActivityTraceData) -> None: - if not trace.epoched: - return - first = float(np.nanmin(trace.x_values)) - last = float(np.nanmax(trace.x_values)) - start = int(np.ceil(first / trace.pnts) * trace.pnts) - for boundary in range(start, int(np.floor(last / trace.pnts) * trace.pnts) + 1, trace.pnts): - if boundary < first or boundary > last: - continue - axis.axvline(float(boundary), color="red", linestyle="-", linewidth=0.75) - axis.text( - float(boundary), - 0.98, - f"epoch {boundary // trace.pnts}", - color="red", - transform=axis.get_xaxis_transform(), - rotation=45, - fontsize=8, - ) - - -def _format_scrollplot_axis(axis: Any, trace: ActivityTraceData) -> None: - first = float(trace.x_values[0]) - last = float(trace.x_values[-1]) - axis.set_xlim(first, last) - tick_count = min(6, trace.x_values.size) - ticks = np.unique(np.linspace(first, last, tick_count).round().astype(int)) - labels = [] - for tick in ticks: - sample = (int(tick) - 1) % trace.pnts if trace.epoched else min(max(int(tick) - 1, 0), trace.pnts - 1) - labels.append(f"{trace.times_ms[sample]:g}") - axis.set_xticks(ticks) - axis.set_xticklabels(labels) - - -def _event_items(events: Any) -> list[dict[str, Any]]: - if events is None: - return [] - if isinstance(events, dict): - if "latency" not in events: - return [] - latencies = np.asarray(events["latency"], dtype=object).ravel() - types = np.asarray(events.get("type", [""] * latencies.size), dtype=object).ravel() - epochs = np.asarray(events.get("epoch", [None] * latencies.size), dtype=object).ravel() - event_items = [] - for index, latency in enumerate(latencies): - event = { - "type": _event_scalar(types[min(index, types.size - 1)]), - "latency": _event_scalar(latency), - } - epoch = _event_scalar(epochs[min(index, epochs.size - 1)]) - if epoch is not None: - event["epoch"] = epoch - event_items.append(event) - return event_items - if isinstance(events, np.ndarray): - events = events.tolist() - try: - event_values = list(events) - except TypeError: - return [] - return [event for event in event_values if isinstance(event, dict)] - - -def _event_scalar(value: Any) -> Any: - array = np.asarray(value) - if array.shape == (): - return array.item() - if array.size == 1: - item = array.ravel()[0] - return item.item() if hasattr(item, "item") else item - return value - - -def _event_color_map(event_items: list[dict[str, Any]]) -> dict[str, str]: - labels: list[str] = [] - for event in event_items: - label = str(_event_scalar(event.get("type", ""))) - if label not in labels: - labels.append(label) - return {label: _EVENT_COLORS[index % len(_EVENT_COLORS)] for index, label in enumerate(labels)} - - -def _srate_from_times(times_ms: np.ndarray) -> float: - if times_ms.size < 2: - return 1.0 - interval_ms = float(np.nanmedian(np.diff(times_ms))) - if not np.isfinite(interval_ms) or interval_ms <= 0: - return 1.0 - return 1000.0 / interval_ms - - -def _set_window_title(figure: Any, title: str) -> None: - figure.set_label(title) - manager = getattr(getattr(figure, "canvas", None), "manager", None) - if manager is not None and hasattr(manager, "set_window_title"): - manager.set_window_title(title) - - -__all__ = ["ActivityTraceData", "build_navigable_dashboard"] diff --git a/src/eegprep/plugins/ICLabel/_prop_numerics.py b/src/eegprep/plugins/ICLabel/_prop_numerics.py deleted file mode 100644 index 619419c8..00000000 --- a/src/eegprep/plugins/ICLabel/_prop_numerics.py +++ /dev/null @@ -1,474 +0,0 @@ -"""Data assembly and numerical helpers for ICLabel property dashboards.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any - -import numpy as np - -from eegprep.functions.popfunc._plot_utils import ( - channel_labels, - component_activations, - component_channel_indices, - component_map_data, - eeg_epoch_data, - eeg_times_ms, - numeric_vector, - parse_plot_options_text, -) -from eegprep.functions.popfunc._chanutils import chanlocs_as_list -from eegprep.functions.popfunc._rejection import component_rejection_flags, one_based_indices -from eegprep.functions.sigprocfunc.spectopo import compute_spectra -from eegprep.plugins.dipfit._utils import normalize_model_list - - -DEFAULT_ICLABEL_CLASSES = ("Brain", "Muscle", "Eye", "Heart", "Line Noise", "Channel Noise", "Other") - - -@dataclass(frozen=True) -class ClassifierData: - """Normalized component-classifier output from ``EEG.etc.ic_classification``.""" - - name: str - classes: tuple[str, ...] - probabilities: np.ndarray - - -@dataclass(frozen=True) -class DipfitData: - """Normalized localized DIPFIT model for one ICA component.""" - - positions: np.ndarray - moments: np.ndarray | None - rv_percent: float | None - dmr: float | None - coordformat: str - - -@dataclass(frozen=True) -class ExtendedPropertyData: - """Data assembled for one extended channel/component property dashboard.""" - - typecomp: int - index: int - label: str - figure_title: str - topography_title: str - topography_values: Any - topography_chanlocs: list[dict[str, Any]] - activity: np.ndarray - times_ms: np.ndarray - activity_title: str - image_data: np.ndarray - image_extent: tuple[float, float, float, float] - image_title: str - spectrum_freqs: np.ndarray - spectrum_power: np.ndarray - spectrum_title: str - classifier: ClassifierData | None - class_probabilities: np.ndarray | None - pvaf: float | None - dipfit: DipfitData | None - rejected: bool | None - - -def classifier_names(EEG: dict[str, Any]) -> list[str]: - """Return classifier field names available under ``EEG.etc.ic_classification``.""" - etc = EEG.get("etc") or {} - if not isinstance(etc, dict): - return [] - classifications = etc.get("ic_classification") or {} - if not isinstance(classifications, dict): - return [] - return [str(name) for name in classifications if str(name)] - - -def classifier_default_index(classifiers: list[str]) -> int: - """Return EEGLAB popup index for the default component classifier.""" - for index, name in enumerate(classifiers, start=1): - if name.lower() == "iclabel": - return index - return 1 - - -def classifier_name_from_gui(EEG: dict[str, Any], value: Any) -> str: - """Resolve a GUI popup value to a classifier field name.""" - classifiers = classifier_names(EEG) - if not classifiers: - return "" - if isinstance(value, str): - for classifier in classifiers: - if classifier.lower() == value.lower(): - return classifier - return classifiers[classifier_default_index(classifiers) - 1] - try: - index = int(value) - 1 - except (TypeError, ValueError): - index = classifier_default_index(classifiers) - 1 - if 0 <= index < len(classifiers): - return classifiers[index] - return classifiers[classifier_default_index(classifiers) - 1] - - -def resolve_classifier_data( - EEG: dict[str, Any], - classifier_name: str = "", - *, - component_total: int | None = None, - require: bool = False, -) -> ClassifierData | None: - """Return normalized classifier data or ``None`` when no classifier is available.""" - classifiers = classifier_names(EEG) - if not classifiers: - if require: - raise ValueError("No component classifier data found in EEG.etc.ic_classification") - return None - resolved_name = _resolve_classifier_name(classifiers, classifier_name) - record = (EEG.get("etc") or {})["ic_classification"][resolved_name] - if not isinstance(record, dict): - raise ValueError(f"Classifier {resolved_name!r} must be stored as a dictionary") - probabilities = np.asarray(record.get("classifications", []), dtype=float) - if probabilities.ndim != 2 or probabilities.size == 0: - raise ValueError(f"Classifier {resolved_name!r} is missing a 2-D classifications matrix") - if component_total is not None and probabilities.shape[0] != int(component_total): - raise ValueError( - f"Classifier {resolved_name!r} has {probabilities.shape[0]} rows for {component_total} ICA components" - ) - classes = _classifier_classes(record, resolved_name, probabilities.shape[1]) - return ClassifierData(resolved_name, classes, probabilities) - - -def resolve_dipfit_data(EEG: dict[str, Any], component_index: int) -> DipfitData | None: - """Return normalized DIPFIT model data for a 1-based component index.""" - models = normalize_model_list(EEG) - index = int(component_index) - if index < 1: - raise ValueError("component index must be 1-based") - if index > len(models): - return None - model = models[index - 1] - raw_positions = np.asarray(model.get("posxyz", []), dtype=float) - if raw_positions.size == 0: - return None - positions = _dipfit_matrix(raw_positions, "posxyz", index) - if not np.all(np.isfinite(positions)): - raise ValueError(f"DIPFIT model for component {index} contains non-finite posxyz values") - moments = _dipfit_moments(model.get("momxyz", []), positions.shape[0], index) - rv = _finite_float(model.get("rv")) - coordformat = "" - dipfit = EEG.get("dipfit") - if isinstance(dipfit, dict): - coordformat = str(dipfit.get("coordformat") or "") - return DipfitData( - positions=positions, - moments=moments, - rv_percent=None if rv is None else rv * 100.0, - dmr=_dipole_moment_ratio(moments), - coordformat=coordformat, - ) - - -def selected_property_indices( - EEG: dict[str, Any], - typecomp: int | bool, - values: Any, - *, - default_all: bool = True, -) -> list[int]: - """Normalize EEGLAB-facing channel/component selections to 1-based indices.""" - limit = int(EEG.get("nbchan", 0) or 0) if int(bool(typecomp)) else component_count(EEG) - return one_based_indices(values, limit=limit, default_all=default_all) - - -def component_count(EEG: dict[str, Any]) -> int: - """Return the number of available ICA components.""" - icaact = EEG.get("icaact") - if icaact is not None and np.asarray(icaact).size: - values = np.asarray(icaact) - if values.ndim >= 2: - return int(values.shape[0]) - weights = np.asarray(EEG.get("icaweights", [])) - if weights.ndim == 2 and weights.size: - return int(weights.shape[0]) - winv = np.asarray(EEG.get("icawinv", [])) - if winv.ndim == 2 and winv.size: - return int(winv.shape[1]) - return 0 - - -def has_component_classifier(EEG: dict[str, Any], classifier_name: str = "") -> bool: - """Return whether usable classifier data are available for the current ICA.""" - try: - return ( - resolve_classifier_data(EEG, classifier_name, component_total=component_count(EEG), require=False) - is not None - ) - except ValueError as exc: - if "missing a 2-D classifications matrix" in str(exc): - return False - raise - - -def component_rejection_status( - EEG: dict[str, Any], - component_index: int, - *, - component_total: int | None = None, -) -> bool: - """Return the current ``EEG.reject.gcompreject`` status for one component.""" - total = component_count(EEG) if component_total is None else int(component_total) - index = int(component_index) - if index < 1 or index > total: - raise ValueError("component index is outside available ICA components") - return bool(component_rejection_flags(EEG, total, create=False)[index - 1]) - - -def build_extended_property_data( - EEG: dict[str, Any], - typecomp: int | bool, - index: int, - *, - spec_opt: Any = None, - erp_opt: Any = None, - classifier_name: str = "", -) -> ExtendedPropertyData: - """Assemble the dashboard data for one EEGLAB-facing channel/component index.""" - del erp_opt - typecomp = int(bool(typecomp)) - data = eeg_epoch_data(EEG) - times_ms = eeg_times_ms(EEG) - if typecomp: - return _channel_dashboard_data(EEG, data, times_ms, int(index), spec_opt) - return _component_dashboard_data(EEG, data, times_ms, int(index), spec_opt, classifier_name) - - -def _channel_dashboard_data( - EEG: dict[str, Any], - data: np.ndarray, - times_ms: np.ndarray, - index: int, - spec_opt: Any, -) -> ExtendedPropertyData: - labels = channel_labels(EEG) - if index < 1 or index > data.shape[0]: - raise ValueError("channel index is outside available channels") - label = labels[index - 1] if index - 1 < len(labels) else str(index) - activity = np.array(data[index - 1 : index], dtype=float, copy=True) - spectrum_freqs, spectrum_power = _spectrum(activity, EEG, spec_opt) - image_data, image_extent, image_title = _activity_image(activity, times_ms, f"Epoched Channel {label} Activity") - return ExtendedPropertyData( - typecomp=1, - index=index, - label=label, - figure_title=f"Channel {label} - pop_prop_extended()", - topography_title=f"Channel {label}", - topography_values=index, - topography_chanlocs=chanlocs_as_list(EEG.get("chanlocs")), - activity=activity, - times_ms=times_ms, - activity_title="Channel Time Series", - image_data=image_data, - image_extent=image_extent, - image_title=image_title, - spectrum_freqs=spectrum_freqs, - spectrum_power=spectrum_power, - spectrum_title="Channel Activity Power Spectrum", - classifier=None, - class_probabilities=None, - pvaf=None, - dipfit=None, - rejected=None, - ) - - -def _component_dashboard_data( - EEG: dict[str, Any], - data: np.ndarray, - times_ms: np.ndarray, - index: int, - spec_opt: Any, - classifier_name: str, -) -> ExtendedPropertyData: - activity_all = component_activations(EEG) - maps, map_chanlocs = component_map_data(EEG) - if index < 1 or index > activity_all.shape[0]: - raise ValueError("component index is outside available ICA components") - activity = np.array(activity_all[index - 1 : index], dtype=float, copy=True) - classifier = resolve_classifier_data(EEG, classifier_name, component_total=activity_all.shape[0], require=False) - probabilities = ( - None if classifier is None else np.array(classifier.probabilities[index - 1], dtype=float, copy=True) - ) - spectrum_freqs, spectrum_power = _spectrum(activity, EEG, spec_opt) - image_data, image_extent, image_title = _activity_image(activity, times_ms, f"Epoched IC{index} Activity") - return ExtendedPropertyData( - typecomp=0, - index=index, - label=f"IC{index}", - figure_title=f"IC{index} - pop_prop_extended()", - topography_title=f"IC{index}", - topography_values=maps[:, index - 1], - topography_chanlocs=map_chanlocs, - activity=activity, - times_ms=times_ms, - activity_title=f"Scrolling IC{index} Activity", - image_data=image_data, - image_extent=image_extent, - image_title=image_title, - spectrum_freqs=spectrum_freqs, - spectrum_power=spectrum_power, - spectrum_title=f"IC{index} Activity Power Spectrum", - classifier=classifier, - class_probabilities=probabilities, - pvaf=_component_pvaf(EEG, data, maps, activity, index), - dipfit=resolve_dipfit_data(EEG, index), - rejected=component_rejection_status(EEG, index, component_total=activity_all.shape[0]), - ) - - -def _spectrum(activity: np.ndarray, EEG: dict[str, Any], spec_opt: Any) -> tuple[np.ndarray, np.ndarray]: - options = parse_plot_options_text(spec_opt) - flat = np.asarray(activity, dtype=float).reshape(1, -1) - spectra, freqs, _std = compute_spectra( - flat, - int(EEG.get("pnts", flat.shape[1]) or flat.shape[1]), - float(EEG.get("srate", 1.0) or 1.0), - winsize=_first_int(options.get("winsize")), - overlap=_first_int(options.get("overlap")) or 0, - nfft=_first_int(options.get("nfft")), - ) - return freqs, spectra[0] - - -def _activity_image( - activity: np.ndarray, - times_ms: np.ndarray, - epoched_title: str, -) -> tuple[np.ndarray, tuple[float, float, float, float], str]: - trace = np.asarray(activity[0], dtype=float) - trace = trace - float(np.nanmean(trace)) - if trace.ndim == 1: - trace = trace[:, np.newaxis] - if trace.shape[1] > 1: - image = trace.T - extent = (float(times_ms[0]), float(times_ms[-1]), 1.0, float(trace.shape[1])) - return image, extent, epoched_title - flat = trace[:, 0] - line_count = min(200, max(1, int(np.floor(np.sqrt(flat.size))))) - frame_count = max(1, flat.size // line_count) - image = flat[: line_count * frame_count].reshape(line_count, frame_count) - extent = (0.0, float(frame_count - 1), 1.0, float(line_count)) - return image, extent, "Continuous Data" - - -def _dipfit_matrix(values: np.ndarray, field_name: str, component_index: int) -> np.ndarray: - matrix = values - if matrix.ndim == 1: - matrix = matrix.reshape(1, -1) - if matrix.ndim != 2 or matrix.shape[1] < 3: - raise ValueError( - f"DIPFIT model for component {component_index} must contain {field_name} rows with 3 coordinates" - ) - return np.array(matrix[:, :3], dtype=float, copy=True) - - -def _dipfit_moments(values: Any, position_count: int, component_index: int) -> np.ndarray | None: - raw_moments = np.asarray(values, dtype=float) - if raw_moments.size == 0: - return None - moments = _dipfit_matrix(raw_moments, "momxyz", component_index) - if moments.shape[0] != position_count: - raise ValueError(f"DIPFIT model for component {component_index} must have matching posxyz and momxyz rows") - if not np.all(np.isfinite(moments)): - raise ValueError(f"DIPFIT model for component {component_index} contains non-finite momxyz values") - return moments - - -def _dipole_moment_ratio(moments: np.ndarray | None) -> float | None: - if moments is None or moments.shape[0] != 2: - return None - norms = np.linalg.norm(moments, axis=1) - if not np.all(np.isfinite(norms)) or np.any(norms <= 0.0): - return None - ratio = float(norms[0] / norms[1]) - return ratio if ratio >= 1.0 else 1.0 / ratio - - -def _finite_float(value: Any) -> float | None: - try: - numeric = float(np.asarray(value).reshape(())) - except (TypeError, ValueError): - return None - return numeric if np.isfinite(numeric) else None - - -def _component_pvaf( - EEG: dict[str, Any], - data: np.ndarray, - maps: np.ndarray, - activity: np.ndarray, - index: int, -) -> float | None: - if maps.shape[0] == 0: - return None - icachansind = component_channel_indices(EEG, data.shape[0]) - if maps.shape[0] != icachansind.size: - return None - flat_data = data[icachansind, :, :].reshape(icachansind.size, -1) - component_trace = activity.reshape(1, -1) - projection = maps[:, index - 1 : index] @ component_trace - datavar = float(np.nanmean(np.nanvar(flat_data, axis=1))) - if not np.isfinite(datavar) or datavar <= 0: - return None - projvar = float(np.nanmean(np.nanvar(flat_data - projection, axis=1))) - if not np.isfinite(projvar): - return None - return 100.0 * (1.0 - projvar / datavar) - - -def _resolve_classifier_name(classifiers: list[str], classifier_name: str) -> str: - if classifier_name: - for name in classifiers: - if name.lower() == str(classifier_name).lower(): - return name - raise ValueError(f"Classifier {classifier_name!r} was not found in EEG.etc.ic_classification") - return classifiers[classifier_default_index(classifiers) - 1] - - -def _classifier_classes(record: dict[str, Any], classifier_name: str, class_count: int) -> tuple[str, ...]: - raw_classes = record.get("classes", []) - classes = [str(item) for item in np.asarray(raw_classes, dtype=object).ravel().tolist() if str(item)] - if not classes: - if classifier_name.lower() == "iclabel" and class_count == len(DEFAULT_ICLABEL_CLASSES): - return DEFAULT_ICLABEL_CLASSES - return tuple(f"Class {index}" for index in range(1, class_count + 1)) - if len(classes) != class_count: - raise ValueError( - f"Classifier {classifier_name!r} has {class_count} probability columns but {len(classes)} class names" - ) - return tuple(classes) - - -def _first_int(value: Any) -> int | None: - vector = numeric_vector(value) - if vector.size == 0: - return None - return int(vector[0]) - - -__all__ = [ - "DEFAULT_ICLABEL_CLASSES", - "ClassifierData", - "DipfitData", - "ExtendedPropertyData", - "build_extended_property_data", - "classifier_default_index", - "classifier_name_from_gui", - "classifier_names", - "component_count", - "component_rejection_status", - "has_component_classifier", - "resolve_classifier_data", - "resolve_dipfit_data", - "selected_property_indices", -] diff --git a/src/eegprep/plugins/ICLabel/eeg_autocorr.py b/src/eegprep/plugins/ICLabel/eeg_autocorr.py deleted file mode 100644 index ee8cd208..00000000 --- a/src/eegprep/plugins/ICLabel/eeg_autocorr.py +++ /dev/null @@ -1,58 +0,0 @@ -"""EEG autocorrelation functions.""" - -import numpy as np -from scipy.signal import resample_poly -from numpy.fft import fft, ifft - - -def eeg_autocorr(EEG, pct_data=None): - """Compute autocorrelation of ICA components. - - Parameters - ---------- - EEG : dict - EEG data structure with icaact - pct_data : float, optional - Percentage of data to use (default 100) - - Returns - ------- - ac : ndarray - Autocorrelation array - """ - if pct_data is None: - pct_data = 100 - - # convert EEG['icaact'] to single precision - EEG['icaact'] = EEG['icaact'].astype(np.float32) - - ncomp = EEG['icaact'].shape[0] - nfft = 2 ** np.ceil(np.log2(2 * EEG['pnts'] - 1)).astype(int) - - # Calculate autocorrelation - c = np.zeros((ncomp, nfft)) - for it in range(ncomp): - comp = EEG['icaact'][it, :].reshape(-1) - Xtmp = fft(comp, nfft) - Xtmp = Xtmp.astype( - np.complex64 - ) # matches MATLAB's single precision since python convert to double precision by default - X = np.abs(Xtmp) ** 2 - c[it, :] = np.real(ifft(X)) - - # Adjust the size of the autocorrelation to match sampling rate - if EEG['pnts'] < EEG['srate']: - ac = np.hstack([c[:, : EEG['pnts']], np.zeros((ncomp, EEG['srate'] - EEG['pnts'] + 1))]) - else: - ac = c[:, : int(EEG['srate']) + 1] - - # Normalize by the 0-tap of the autocorrelation - ac /= ac[:, [0]] - - # Resample to 1 second at 100 samples/sec - - # print the size of the second dim of ac - ac = resample_poly(ac.T, up=100, down=EEG['srate']).T - ac = ac[:, 1:] - - return ac diff --git a/src/eegprep/plugins/ICLabel/eeg_autocorr_fftw.py b/src/eegprep/plugins/ICLabel/eeg_autocorr_fftw.py deleted file mode 100644 index d00242cc..00000000 --- a/src/eegprep/plugins/ICLabel/eeg_autocorr_fftw.py +++ /dev/null @@ -1,61 +0,0 @@ -"""EEG autocorrelation computation using FFTW. - -This module provides functions for computing autocorrelation of EEG ICA components using -fast Fourier transform methods. -""" - -import numpy as np -from scipy.fft import fft, ifft, next_fast_len -from scipy.signal import resample_poly - - -def eeg_autocorr_fftw(EEG, pct_data=100): - """Compute autocorrelation of EEG ICA components using FFT. - - Parameters - ---------- - EEG : dict - EEG data structure with 'icaact', 'pnts', 'srate' fields. - pct_data : float, optional - Percentage of data to use. Default 100. - - Returns - ------- - ndarray - Autocorrelation array. - """ - # FFT length - nfft = next_fast_len(2 * EEG['pnts'] - 1) - - # Initialize autocorrelation array - ncomp = EEG['icaact'].shape[0] - ac = np.zeros((ncomp, nfft)) - - # Calculate autocorrelation using FFT - for it in range(EEG['icaact'].shape[0]): - # Apply FFT - X = fft(EEG['icaact'][it, :, :], n=nfft, axis=0) - # Compute the mean of the power spectrum - ac[it, :] = np.mean(np.abs(X) ** 2, axis=1) - - # Inverse FFT to get autocorrelation - ac = ifft(ac, axis=1) - - # make sure the data is in real - ac = np.real(ac) - - # Adjust the size of autocorrelation array - if EEG['pnts'] < EEG['srate']: - # ac = np.hstack( [ac[:, :EEG['pnts']], np.zeros((ncomp , EEG['srate'] - EEG['pnts'] + 1))]) - ac = np.concatenate((ac[:, : EEG['pnts']], np.zeros((ac.shape[0], EEG['srate'] - EEG['pnts'] + 1))), axis=1) - else: - ac = ac[:, : int(EEG['srate']) + 1] - - # Normalize by 0-lag autocorrelation - ac = ac / ac[:, 0][:, np.newaxis] - - # resample to 1 second at 100 samples/sec - ac = resample_poly(ac.T, up=100, down=EEG['srate']).T - ac = ac[:, 1:101] - - return ac diff --git a/src/eegprep/plugins/ICLabel/eeg_autocorr_welch.py b/src/eegprep/plugins/ICLabel/eeg_autocorr_welch.py deleted file mode 100644 index 6dfaaf15..00000000 --- a/src/eegprep/plugins/ICLabel/eeg_autocorr_welch.py +++ /dev/null @@ -1,84 +0,0 @@ -"""EEG autocorrelation computation using Welch method. - -This module provides functions for computing autocorrelation of EEG ICA components using -the Welch method for spectral estimation. -""" - -import numpy as np -from scipy.signal import resample_poly -import random -from numpy.fft import fft, ifft - - -def eeg_autocorr_welch(EEG, pct_data=100): - """Compute autocorrelation of EEG ICA components using Welch method. - - Parameters - ---------- - EEG : dict - EEG data structure with 'icaweights', 'icaact', 'pnts', 'srate' fields. - pct_data : float, optional - Percentage of data to use. Default 100. - - Returns - ------- - ndarray - Autocorrelation array. - """ - # clean input cutoff freq - if pct_data is None or pct_data == 0: - pct_data = 100 - - # setup constants - ncomp = EEG['icaweights'].shape[0] - n_points = min(EEG['pnts'], EEG['srate'] * 3) - nfft = 2 ** (int(np.log2(n_points * 2 - 1)) + 1) - cutoff = (EEG['pnts'] // n_points) * n_points - index = np.add.outer( - np.ceil(np.arange(0, cutoff - n_points + 1, n_points // 2)).astype(int), np.arange(n_points) - ).astype(int) - index = index.T - - # separate data segments - if pct_data != 100: - random.seed(0) - n_seg = index.shape[0] * EEG['trials'] - subset = random.sample(range(n_seg), int(np.ceil(n_seg * pct_data / 100))) - random.seed() # restore normal random behavior - temp = np.reshape(EEG['icaact'][:, index, :], (ncomp, *index.shape, EEG['trials'])) - segments = temp[:, :, subset] - else: - segments = np.reshape(EEG['icaact'][:, index, :], (ncomp, *index.shape, EEG['trials'])) - - # calc autocorrelation - ac = np.zeros((ncomp, nfft)) - for it in range(ncomp): - fftpow = np.mean(np.abs(fft(segments[it, :, :], nfft, axis=0)) ** 2, axis=1) - ac[it, :] = np.real(ifft(fftpow, axis=0)).T - - # normalizefft - if EEG['pnts'] < EEG['srate']: - ac = np.concatenate( - [ - ac[:, : EEG['pnts']] / (ac[:, 0][:, np.newaxis] * np.arange(n_points, 0, -1) / n_points), - np.zeros((ncomp, int(EEG['srate']) - n_points + 1)), - ], - axis=1, - ) - else: - ac = ac[:, : int(EEG['srate']) + 1] / ( - ac[:, 0][:, np.newaxis] - * np.concatenate( - ( - np.arange(n_points, n_points - int(EEG['srate']), -1), - np.array([max(1, n_points - int(EEG['srate']))]), - ) - ) - / n_points - ) - - # resample to 1 second at 100 samples/sec - ac = resample_poly(ac.T, up=100, down=EEG['srate']).T - ac = ac[:, 1:101] - - return ac diff --git a/src/eegprep/plugins/ICLabel/eeg_icalabelstat.py b/src/eegprep/plugins/ICLabel/eeg_icalabelstat.py deleted file mode 100644 index f3cb1a7b..00000000 --- a/src/eegprep/plugins/ICLabel/eeg_icalabelstat.py +++ /dev/null @@ -1,150 +0,0 @@ -"""EEGLAB ICLabel component summary statistics.""" - -from __future__ import annotations - -import sys -from typing import Any, TextIO - -import numpy as np - - -DEFAULT_ICLABEL_CLASSES = ("Brain", "Muscle", "Eye", "Heart", "Line Noise", "Channel Noise", "Other") - - -def eeg_icalabelstat( - EEG: dict[str, Any], - threshold: float | list[float] | tuple[float, ...] | np.ndarray = 0.9, - *, - verbose: bool = True, - stream: TextIO | None = None, -) -> dict[str, Any]: - """Return and optionally print ICLabel class statistics. - - This mirrors EEGLAB's ``plugins/ICLabel/eeg_icalabelstat.m`` summary, - including class order and the historical ``IClabel`` console label, while - returning structured Python statistics for programmatic use. - - Args: - EEG: EEGPrep/EEGLAB-style dataset with ICLabel classifications under - ``EEG["etc"]["ic_classification"]["ICLabel"]``. - threshold: Scalar probability threshold applied to every class, or one - threshold per ICLabel class. - verbose: Print EEGLAB-style summary lines when true. - stream: Destination for printed summary lines. Defaults to stdout. - - Returns: - Dictionary containing class names, thresholds, per-class counts, - 1-based component indices above threshold, mean probabilities, dominant - class counts, and current rejected/kept tallies. - """ - iclabel_state = _iclabel_state(EEG) - classifications = _classification_matrix(iclabel_state) - classes = _class_names(iclabel_state, classifications.shape[1]) - thresholds = _threshold_vector(threshold, len(classes)) - flags = _rejection_flags(EEG, classifications.shape[0]) - - above_threshold = classifications > thresholds[np.newaxis, :] - counts = np.sum(above_threshold, axis=0).astype(int) - component_indices = [ - (np.flatnonzero(above_threshold[:, class_index]) + 1).astype(int).tolist() - for class_index in range(len(classes)) - ] - rejected_counts = np.sum(above_threshold & flags[:, np.newaxis], axis=0).astype(int) - kept_counts = np.sum(above_threshold & ~flags[:, np.newaxis], axis=0).astype(int) - dominant_class_indices = np.argmax(classifications, axis=1) - - stats = { - "classes": classes, - "threshold": thresholds, - "component_count": int(classifications.shape[0]), - "counts": counts, - "component_indices": component_indices, - "mean_probability": np.mean(classifications, axis=0), - "dominant_counts": np.bincount(dominant_class_indices, minlength=len(classes)).astype(int), - "dominant_class_indices": (dominant_class_indices + 1).astype(int), - "rejected_component_count": int(np.sum(flags)), - "kept_component_count": int(classifications.shape[0] - np.sum(flags)), - "rejected_counts": rejected_counts, - "kept_counts": kept_counts, - } - if verbose: - _print_summary(stats, stream or sys.stdout) - return stats - - -def _iclabel_state(EEG: dict[str, Any]) -> dict[str, Any]: - try: - state = EEG["etc"]["ic_classification"]["ICLabel"] - except KeyError as exc: - raise ValueError("No ICLabel classifications found. Run pop_iclabel first.") from exc - if not isinstance(state, dict): - raise ValueError("No ICLabel classifications found. Run pop_iclabel first.") - return state - - -def _classification_matrix(iclabel_state: dict[str, Any]) -> np.ndarray: - try: - classifications = np.asarray(iclabel_state["classifications"], dtype=float) - except KeyError as exc: - raise ValueError("No ICLabel classifications found. Run pop_iclabel first.") from exc - if classifications.ndim != 2 or classifications.size == 0: - raise ValueError("ICLabel classifications must be a non-empty 2D array.") - if not np.isfinite(classifications).all(): - raise ValueError("ICLabel classifications must contain finite probabilities.") - return classifications - - -def _class_names(iclabel_state: dict[str, Any], class_count: int) -> list[str]: - classes = iclabel_state.get("classes") - if classes is None or np.asarray(classes, dtype=object).size == 0: - if class_count == len(DEFAULT_ICLABEL_CLASSES): - return list(DEFAULT_ICLABEL_CLASSES) - raise ValueError("ICLabel classes are missing and classifications do not have the 7 standard columns.") - names = [_string_name(value) for value in np.asarray(classes, dtype=object).ravel().tolist()] - if len(names) != class_count: - raise ValueError(f"ICLabel class list has {len(names)} labels for {class_count} probability columns.") - return names - - -def _string_name(value: Any) -> str: - if isinstance(value, bytes): - return value.decode("utf-8") - return str(value) - - -def _threshold_vector(threshold: float | list[float] | tuple[float, ...] | np.ndarray, class_count: int) -> np.ndarray: - values = np.asarray(threshold, dtype=float) - if values.ndim == 0 or values.size == 1: - thresholds = np.repeat(float(values.reshape(-1)[0]), class_count) - else: - thresholds = values.reshape(-1) - if thresholds.size != class_count: - raise ValueError(f"threshold must be scalar or contain {class_count} values.") - if not np.isfinite(thresholds).all() or (thresholds < 0).any() or (thresholds > 1).any(): - raise ValueError("ICLabel thresholds must be finite probabilities between 0 and 1.") - return thresholds.astype(float, copy=True) - - -def _rejection_flags(EEG: dict[str, Any], component_count: int) -> np.ndarray: - reject = EEG.get("reject") or {} - flags = reject.get("gcompreject") if isinstance(reject, dict) else None - if flags is None: - return np.zeros(component_count, dtype=bool) - normalized = np.asarray(flags).reshape(-1) - if normalized.size != component_count: - raise ValueError("EEG.reject.gcompreject must match the number of ICLabel components.") - return normalized.astype(bool) - - -def _print_summary(stats: dict[str, Any], stream: TextIO) -> None: - classes = stats["classes"] - counts = stats["counts"] - thresholds = stats["threshold"] - component_count = stats["component_count"] - for class_name, count, class_threshold in zip(classes, counts, thresholds): - label = f'IClabel class "{class_name}"' - percent = int(np.round(float(class_threshold) * 100)) - print(f"{label:>30}: {int(count)}/{component_count} components at {percent}% threshold", file=stream) - - -__all__ = ["DEFAULT_ICLABEL_CLASSES", "eeg_icalabelstat"] diff --git a/src/eegprep/plugins/ICLabel/eeg_icflag.py b/src/eegprep/plugins/ICLabel/eeg_icflag.py deleted file mode 100644 index 3be6ab1b..00000000 --- a/src/eegprep/plugins/ICLabel/eeg_icflag.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Flag independent components based on ICLabel classifications.""" - -import numpy as np - - -def eeg_icflag(EEG, thresholds): - """Flag independent components based on ICLabel classification probabilities. - - Parameters - ---------- - EEG : dict - EEG structure with ICLabel classifications in EEG['etc']['ic_classification']['ICLabel']['classifications'] - thresholds : array-like, shape (7, 2) - Threshold matrix where each row corresponds to an IC class: - [Brain, Muscle, Eye, Heart, Line Noise, Channel Noise, Other] - Each row contains [min_threshold, max_threshold]. - Use NaN in either column to ignore a class, matching EEGLAB's blank - threshold fields. - - Returns - ------- - EEG : dict - EEG structure with added 'reject' field containing flags for each component - - Examples - -------- - # Flag components with Muscle > 0.9 OR Eye > 0.9 - thresholds = np.array([ - [np.nan, np.nan], # Brain - [0.9, 1.0], # Muscle - [0.9, 1.0], # Eye - [np.nan, np.nan], # Heart - [np.nan, np.nan], # Line Noise - [np.nan, np.nan], # Channel Noise - [np.nan, np.nan], # Other - ]) - EEG = eeg_icflag(EEG, thresholds) - """ - try: - ic_class = EEG["etc"]["ic_classification"]["ICLabel"]["classifications"] - except KeyError as exc: - raise ValueError("EEG structure does not contain ICLabel classifications") from exc - n_comps = ic_class.shape[0] - - thresholds = np.asarray(thresholds, dtype=float) - - if thresholds.shape != (7, 2): - raise ValueError("Thresholds must be a 7x2 array") - - reject = np.zeros(n_comps, dtype=bool) - - for class_idx in range(7): - min_thresh = thresholds[class_idx, 0] - max_thresh = thresholds[class_idx, 1] - if np.isnan(min_thresh) or np.isnan(max_thresh): - continue - probs = ic_class[:, class_idx] - reject |= (probs > min_thresh) & (probs < max_thresh) - - EEG["reject"] = dict(EEG.get("reject") or {}) - EEG["reject"]["gcompreject"] = reject.astype(int) - - return EEG diff --git a/src/eegprep/plugins/ICLabel/eeg_rpsd.py b/src/eegprep/plugins/ICLabel/eeg_rpsd.py deleted file mode 100644 index 7855fd8c..00000000 --- a/src/eegprep/plugins/ICLabel/eeg_rpsd.py +++ /dev/null @@ -1,72 +0,0 @@ -"""EEG relative power spectral density computation.""" - -import numpy as np -from numpy.fft import fft - - -def eeg_rpsd(EEG, nfreqs=None, pct_data=100): - """Compute relative power spectral density for ICA components. - - Parameters - ---------- - EEG : dict - EEG data structure with ICA activations. - nfreqs : int, optional - Number of frequency bins. Default is Nyquist frequency. - pct_data : float, optional - Percentage of data to use. Default is 100. - - Returns - ------- - ndarray - Power spectral density in dB for each component. - """ - # clean input cutoff freq - nyquist = EEG['srate'] // 2 - if nfreqs is None or nfreqs > nyquist: - nfreqs = nyquist - nfreqs = int(nfreqs) - - # setup constants - ncomp = EEG['icaweights'].shape[0] - - # Hamming window - n_points = min(EEG['pnts'], EEG['srate']) - m = n_points - isOddLength = m % 2 - if isOddLength: - x = np.arange(0, (m - 1) / 2 + 1) / (m - 1) - else: - x = np.arange(0, m / 2) / (m - 1) - - a = 0.54 - window = a - (1 - a) * np.cos(2 * np.pi * x) - if isOddLength: - window = np.concatenate([window, window[-2::-1]]) - else: - window = np.concatenate([window, window[::-1]]) - - cutoff = (EEG['pnts'] // n_points) * n_points - index = ( - np.add.outer(np.ceil(np.arange(0, cutoff - n_points + 1, n_points / 2)).astype(int), np.arange(0, n_points)) - .astype(int) - .transpose() - ) - - rng = np.random.RandomState(0) # rng('default') in MATLAB; local RNG avoids mutating global state - n_seg = index.shape[1] * EEG['trials'] - subset = rng.permutation(n_seg)[: int(n_seg * pct_data / 100)] - - # calculate windowed spectrums - psdmed = np.zeros((ncomp, nfreqs)) - for it in range(ncomp): - temp = np.reshape(EEG['icaact'][it, index, :], (1, index.shape[0], index.shape[1] * EEG['trials'])) - temp = temp[:, :, subset] * window[:, np.newaxis] - temp = fft(temp, int(n_points), axis=1) - temp = np.abs(temp) ** 2 - temp = temp[:, 1 : nfreqs + 1, :] * 2 / (EEG['srate'] * np.sum(window**2)) - if nfreqs == nyquist: - temp[:, -1, :] /= 2 - psdmed[it, :] = 20 * np.log10(np.median(temp, axis=2)) - - return psdmed diff --git a/src/eegprep/plugins/ICLabel/iclabel.py b/src/eegprep/plugins/ICLabel/iclabel.py deleted file mode 100644 index 1961018b..00000000 --- a/src/eegprep/plugins/ICLabel/iclabel.py +++ /dev/null @@ -1,123 +0,0 @@ -"""ICLabel module for classifying independent components in EEG data.""" - -from copy import deepcopy -import os - -import numpy as np - - -_SUPPORTED_ALGORITHMS = ('default', 'lite', 'beta') - - -def iclabel(EEG, algorithm='default', engine=None): - """Apply ICLabel to classify independent components. - - Parameters - ---------- - EEG : dict - EEGLAB EEG structure - algorithm : str - Algorithm to use for classification, passed to the MATLAB/Octave implementation. - Default is 'default'. - engine : str or None - Engine to use for implementation. Options are: - - None: Use the default Python implementation - - 'matlab': Use MATLAB engine - - 'octave': Use Octave engine - - Returns - ------- - EEG : dict - EEGLAB EEG structure with ICLabel classifications added - """ - algorithm = _normalize_algorithm(algorithm) - EEG = deepcopy(EEG) - - # Check if using MATLAB or Octave implementation - if engine in ['matlab', 'octave']: - from eegprep.functions.adminfunc.eeglabcompat import get_eeglab - - # Determine which engine to use - runtime = 'MAT' if engine == 'matlab' else 'OCT' - eeglab = get_eeglab(runtime=runtime) - - # Run ICLabel using MATLAB/Octave, passing the algorithm parameter - if algorithm == 'default': - return eeglab.iclabel(EEG) - else: - return eeglab.iclabel(EEG, algorithm) - - # Default Python implementation - elif engine is None: - if algorithm != 'default': - raise NotImplementedError( - "EEGPrep standalone Python ICLabel only ships the default network (netICL.mat). " - f"The '{algorithm}' network is available only with engine='matlab' or engine='octave' " - "and an EEGLAB ICLabel checkout that provides that artifact." - ) - try: - import torch - except ImportError as e: - raise ImportError( - f"PyTorch is not installed in your environment ({e}). " - f"To include torch, install eegprep as eegprep[all] or " - f"install the torch package manually (see Getting Started " - f"on pytorch.org for specifics for your platform)." - ) from e - - from eegprep.plugins.ICLabel.iclabel_net import ICLabelNet - from eegprep import ICL_feature_extractor - - # ICLABEL Extract ICLabel features from an EEG dataset. - features = ICL_feature_extractor(EEG, True) - - # Equivalent of MATLAB code reshaping - features[0] = np.single( - np.concatenate([features[0], -features[0], features[0][:, ::-1, :, :], -features[0][:, ::-1, :, :]], axis=3) - ) - features[1] = np.single(np.tile(features[1], (1, 1, 1, 4))) - features[2] = np.single(np.tile(features[2], (1, 1, 1, 4))) - # print('Feature 0 shape:', features[0].shape) - # print('Feature 1 shape:', features[1].shape) - # print('Feature 2 shape:', features[2].shape) - - # Load the ICLabelNet model - base_dir = os.path.dirname(os.path.abspath(__file__)) - data_path = os.path.join(base_dir, 'netICL.mat') - model = ICLabelNet(data_path) - - # Convert the features to torch tensors - image = torch.tensor(features[0]).permute(-1, 2, 0, 1) - psdmed = torch.tensor(features[1]).permute(-1, 2, 0, 1) - autocorr = torch.tensor(features[2]).permute(-1, 2, 0, 1) - - # Get the output from the model - output = model(image, psdmed, autocorr) - output_np = output.detach().numpy() - output_np = output_np.T # Transpose the array - output_np = np.reshape(output_np, (-1, 4), order='F') # Reshape to have 4 columns - output_np = np.mean(output_np, axis=1) # Compute the mean along the second axis (columns) - output_np = np.reshape(output_np, (7, -1), order='F') # Reshape to have 7 rows - output_np = output_np.T # Transpose back - - if 'ic_classification' not in EEG['etc']: - EEG['etc']['ic_classification'] = {} - if 'ICLabel' not in EEG['etc']['ic_classification']: - EEG['etc']['ic_classification']['ICLabel'] = {} - - EEG['etc']['ic_classification']['ICLabel']['classes'] = np.array( - ['Brain', 'Muscle', 'Eye', 'Heart', 'Line Noise', 'Channel Noise', 'Other'], dtype=object - ) - EEG['etc']['ic_classification']['ICLabel']['classifications'] = output_np - EEG['etc']['ic_classification']['ICLabel']['version'] = algorithm - - return EEG - else: - raise ValueError(f"Unsupported engine: {engine}. Should be None, 'matlab', or 'octave'") - - -def _normalize_algorithm(algorithm): - normalized = 'default' if algorithm is None else str(algorithm).lower() - if normalized not in _SUPPORTED_ALGORITHMS: - raise ValueError("algorithm must be one of 'default', 'lite', or 'beta'") - return normalized diff --git a/src/eegprep/plugins/ICLabel/iclabel_net.py b/src/eegprep/plugins/ICLabel/iclabel_net.py deleted file mode 100644 index 169fcbfe..00000000 --- a/src/eegprep/plugins/ICLabel/iclabel_net.py +++ /dev/null @@ -1,266 +0,0 @@ -"""ICLabel neural network model for EEG artifact classification. - -This module provides PyTorch implementations of the ICLabel neural network for -classifying EEG components as brain or artifact sources. -""" - -import scipy.io -import torch -import scipy - - -class Reshape(torch.nn.Module): - """Custom reshape layer for PyTorch neural networks.""" - - def __init__(self, shape): - """Initialize reshape layer. - - Parameters - ---------- - shape : tuple - Target shape for reshaping. - """ - super().__init__() - self.shape = shape - - def forward(self, x): - """Forward pass for reshaping. - - Parameters - ---------- - x : torch.Tensor - Input tensor. - - Returns - ------- - torch.Tensor - Reshaped tensor. - """ - return x.view(x.shape[0], *self.shape) - - -class Concatenate(torch.nn.Module): - """Custom concatenation layer for PyTorch neural networks.""" - - def __init__(self, dim): - """Initialize concatenation layer. - - Parameters - ---------- - dim : int - Dimension along which to concatenate. - """ - super().__init__() - self.dim = dim - - def forward(self, x: list): - """Forward pass for concatenation. - - Parameters - ---------- - x : list - List of tensors to concatenate. - - Returns - ------- - torch.Tensor - Concatenated tensor. - """ - return torch.cat(x, dim=self.dim) - - -class ICLabelNet(torch.nn.Module): - """ICLabel neural network for EEG component classification.""" - - def __init__(self, mat_path): - """Initialize ICLabel network from MATLAB weights. - - Parameters - ---------- - mat_path : str - Path to MATLAB .mat file containing network weights. - """ - super().__init__() - iclabel_matlab = scipy.io.loadmat(mat_path) - params = iclabel_matlab['params'][0] - # i = 11 - # print('shape of param', i, torch.tensor(params[i][1]).shape) - self.discriminator_image_layer1_conv = torch.nn.Conv2d( - in_channels=1, out_channels=128, kernel_size=4, stride=2, padding=1, dilation=1 - ) - # print(self.discriminator_image_layer1_conv.weight.shape) - self.discriminator_image_layer1_conv.weight = torch.nn.Parameter( - torch.tensor(params[0][1], dtype=torch.float32).permute(3, 2, 0, 1) - ) - self.discriminator_image_layer1_conv.bias = torch.nn.Parameter( - torch.tensor(params[1][1], dtype=torch.float32).squeeze() - ) - self.discriminator_image_layer1_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_image_layer2_conv = torch.nn.Conv2d( - in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, dilation=1 - ) - self.discriminator_image_layer2_conv.weight = torch.nn.Parameter( - torch.tensor(params[2][1], dtype=torch.float32).permute(3, 2, 0, 1) - ) - self.discriminator_image_layer2_conv.bias = torch.nn.Parameter( - torch.tensor(params[3][1], dtype=torch.float32).squeeze() - ) - self.discriminator_image_layer2_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_image_layer3_conv = torch.nn.Conv2d( - in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, dilation=1 - ) - self.discriminator_image_layer3_conv.weight = torch.nn.Parameter( - torch.tensor(params[4][1], dtype=torch.float32).permute(3, 2, 0, 1) - ) - self.discriminator_image_layer3_conv.bias = torch.nn.Parameter( - torch.tensor(params[5][1], dtype=torch.float32).squeeze() - ) - self.discriminator_image_layer3_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_psdmed_layer1_conv_conv = torch.nn.Conv2d( - in_channels=1, out_channels=128, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_psdmed_layer1_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[6][1], dtype=torch.float32).permute(3, 2, 0, 1) - ) - self.discriminator_psdmed_layer1_conv_conv.bias = torch.nn.Parameter( - torch.tensor(params[7][1], dtype=torch.float32).squeeze() - ) - self.discriminator_psdmed_layer1_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_psdmed_layer2_conv_conv = torch.nn.Conv2d( - in_channels=128, out_channels=256, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_psdmed_layer2_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[8][1], dtype=torch.float32).permute(3, 2, 0, 1) - ) - self.discriminator_psdmed_layer2_conv_conv.bias = torch.nn.Parameter( - torch.tensor(params[9][1], dtype=torch.float32).squeeze() - ) - self.discriminator_psdmed_layer2_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_psdmed_layer3_conv_conv = torch.nn.Conv2d( - in_channels=256, out_channels=1, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_psdmed_layer3_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[10][1], dtype=torch.float32).unsqueeze(3).permute(3, 2, 0, 1) - ) - self.discriminator_psdmed_layer3_conv_conv.bias = torch.nn.Parameter( - torch.tensor(params[11][1], dtype=torch.float32).squeeze(1) - ) - self.discriminator_psdmed_layer3_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_autocorr_layer1_conv_conv = torch.nn.Conv2d( - in_channels=1, out_channels=128, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_autocorr_layer1_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[12][1], dtype=torch.float32).permute(3, 2, 0, 1) - ) - self.discriminator_autocorr_layer1_conv_conv.bias = torch.nn.Parameter( - torch.tensor(params[13][1], dtype=torch.float32).squeeze() - ) - self.discriminator_autocorr_layer1_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_autocorr_layer2_conv_conv = torch.nn.Conv2d( - in_channels=128, out_channels=256, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_autocorr_layer2_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[14][1], dtype=torch.float32).permute(3, 2, 0, 1) - ) - self.discriminator_autocorr_layer2_conv_conv.bias = torch.nn.Parameter( - torch.tensor(params[15][1], dtype=torch.float32).squeeze() - ) - self.discriminator_autocorr_layer2_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_autocorr_layer3_conv_conv = torch.nn.Conv2d( - in_channels=256, out_channels=1, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_autocorr_layer3_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[16][1], dtype=torch.float32).unsqueeze(3).permute(3, 2, 0, 1) - ) - self.discriminator_autocorr_layer3_conv_conv.bias = torch.nn.Parameter( - torch.tensor(params[17][1], dtype=torch.float32).squeeze(1) - ) - self.discriminator_autocorr_layer3_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_psdmed_reshape = Reshape((100, 1, 1)) - self.discriminator_psdmed_concat1 = Concatenate(dim=2) - self.discriminator_psdmed_concat2 = Concatenate(dim=3) - self.discriminator_autocorr_reshape = Reshape((100, 1, 1)) - self.discriminator_autocorr_concat1 = Concatenate(dim=2) - self.discriminator_autocorr_concat2 = Concatenate(dim=3) - self.discriminator_concat = Concatenate(dim=1) - self.discriminator_conv = torch.nn.Conv2d( - in_channels=712, out_channels=7, kernel_size=4, stride=1, padding=0, dilation=1 - ) - self.discriminator_conv.weight = torch.nn.Parameter(torch.tensor(params[18][1]).permute(3, 2, 0, 1)) - self.discriminator_conv.bias = torch.nn.Parameter(torch.tensor(params[19][1]).squeeze()) - self.discriminator_softmax = torch.nn.Softmax(dim=1) - - def forward(self, image, psdmed, autocorr): - """Forward pass through the ICLabelNet model. - - Parameters - ---------- - image : torch.Tensor - Input image tensor. - psdmed : torch.Tensor - PSD median tensor. - autocorr : torch.Tensor - Autocorrelation tensor. - - Returns - ------- - torch.Tensor - Output tensor after softmax. - """ - x_image = self.discriminator_image_layer1_conv(image) - x_image = self.discriminator_image_layer1_relu(x_image) - x_image = self.discriminator_image_layer2_conv(x_image) - x_image = self.discriminator_image_layer2_relu(x_image) - x_image = self.discriminator_image_layer3_conv(x_image) - x_image = self.discriminator_image_layer3_relu(x_image) - # print('x_image', x_image.shape) - - x_psdmed = self.discriminator_psdmed_layer1_conv_conv(psdmed) - x_psdmed = self.discriminator_psdmed_layer1_conv_relu(x_psdmed) - x_psdmed = self.discriminator_psdmed_layer2_conv_conv(x_psdmed) - x_psdmed = self.discriminator_psdmed_layer2_conv_relu(x_psdmed) - x_psdmed = self.discriminator_psdmed_layer3_conv_conv(x_psdmed) - x_psdmed = self.discriminator_psdmed_layer3_conv_relu(x_psdmed) - x_psdmed = self.discriminator_psdmed_reshape(x_psdmed) - x_psdmed = self.discriminator_psdmed_concat1([x_psdmed] * 4) - x_psdmed = self.discriminator_psdmed_concat2([x_psdmed] * 4) - # print('x_psdmed', x_psdmed.shape) - - x_autocorr = self.discriminator_autocorr_layer1_conv_conv(autocorr) - x_autocorr = self.discriminator_autocorr_layer1_conv_relu(x_autocorr) - x_autocorr = self.discriminator_autocorr_layer2_conv_conv(x_autocorr) - x_autocorr = self.discriminator_autocorr_layer2_conv_relu(x_autocorr) - x_autocorr = self.discriminator_autocorr_layer3_conv_conv(x_autocorr) - x_autocorr = self.discriminator_autocorr_layer3_conv_relu(x_autocorr) - x_autocorr = self.discriminator_autocorr_reshape(x_autocorr) - x_autocorr = self.discriminator_autocorr_concat1([x_autocorr] * 4) - x_autocorr = self.discriminator_autocorr_concat2([x_autocorr] * 4) - # print('x_autocorr', x_autocorr.shape) - - x = self.discriminator_concat([x_image, x_psdmed, x_autocorr]) - x = self.discriminator_conv(x) - # print('x', x.shape) - # subtract max value to avoid overflow - x = x - torch.max(x, dim=1, keepdim=True).values - x = self.discriminator_softmax(x) - - return x - - -# if __name__ == "__main__": -# model = ICLabelNet('netICL.mat') -# image_mat = scipy.io.loadmat('net_vars.mat')['in_image'] -# psdmed_mat = scipy.io.loadmat('net_vars.mat')['in_psdmed'] -# autocorr_mat = scipy.io.loadmat('net_vars.mat')['in_autocorr'] -# # assuming third dimension is trivial and last dimension is channel. First two dimensions (32 x 32) are size of topoplot -# image = torch.tensor(image_mat).permute(-1, 2, 0, 1) -# print('image shape', image.shape) -# psdmed = torch.tensor(psdmed_mat).permute(-1, 2, 0, 1) -# print('psd shape', psdmed.shape) -# autocorr = torch.tensor(autocorr_mat).permute(-1, 2, 0, 1) -# print('autocorr shape', autocorr.shape) -# output = model(image, psdmed, autocorr) -# print(output.shape) - -# # save the output to a mat file -# scipy.io.savemat('output4.mat', {'output': output.detach().numpy()}) diff --git a/src/eegprep/plugins/ICLabel/menu.py b/src/eegprep/plugins/ICLabel/menu.py deleted file mode 100644 index 2bf728a1..00000000 --- a/src/eegprep/plugins/ICLabel/menu.py +++ /dev/null @@ -1,39 +0,0 @@ -"""ICLabel and viewprops plugin menu specs for the EEGPrep main window.""" - -from __future__ import annotations - -from eegprep.functions.guifunc.menu_spec import MenuItemSpec, menu_item - - -def iclabel_menu() -> MenuItemSpec: - """Return the EEGLAB ICLabel Tools submenu.""" - return menu_item( - "Classify components using ICLabel", - userdata="startup:off;study:on;ica:on;roi:off", - origin="ICLabel", - children=[ - menu_item( - "Label components", action="pop_iclabel", userdata="startup:off;study:on;ica:on", origin="ICLabel" - ), - menu_item( - "Flag components as artifacts", - action="pop_icflag", - userdata="startup:off;study:on;ica:on", - origin="ICLabel", - ), - menu_item( - "View extended component properties", - action="pop_viewprops:components", - userdata="startup:off;ica:on", - origin="ICLabel", - ), - ], - ) - - -def viewprops_plot_menus() -> tuple[MenuItemSpec, MenuItemSpec]: - """Return viewprops Plot menu additions.""" - return ( - menu_item("View extended channel properties", action="pop_viewprops:channels", origin="viewprops"), - menu_item("View extended component properties", action="pop_viewprops:components", origin="viewprops"), - ) diff --git a/src/eegprep/plugins/ICLabel/netICL.mat b/src/eegprep/plugins/ICLabel/netICL.mat deleted file mode 100644 index f440f406..00000000 Binary files a/src/eegprep/plugins/ICLabel/netICL.mat and /dev/null differ diff --git a/src/eegprep/plugins/ICLabel/pop_icflag.py b/src/eegprep/plugins/ICLabel/pop_icflag.py deleted file mode 100644 index 0ff1d3ef..00000000 --- a/src/eegprep/plugins/ICLabel/pop_icflag.py +++ /dev/null @@ -1,168 +0,0 @@ -"""EEGLAB-style pop wrapper for ICLabel component flagging.""" - -from __future__ import annotations - -import copy -from typing import Any - -import numpy as np - -from eegprep.functions.guifunc.inputgui import inputgui -from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.functions.popfunc._pop_utils import format_history_value -from eegprep.plugins.ICLabel.eeg_icalabelstat import eeg_icalabelstat -from eegprep.plugins.ICLabel.eeg_icflag import eeg_icflag - - -ICLABEL_CLASSES = ("Brain", "Muscle", "Eye", "Heart", "Line Noise", "Channel Noise", "Other") -DEFAULT_ICFLAG_THRESHOLDS = np.array( - [ - [np.nan, np.nan], - [0.9, 1.0], - [0.9, 1.0], - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - [np.nan, np.nan], - ], - dtype=float, -) - - -def pop_icflag( - EEG: dict | list[dict], - thresholds: Any = None, - *, - gui: bool | None = None, - renderer: Any | None = None, - return_com: bool = False, -): - """Flag ICLabel-classified components for later rejection. - - Args: - EEG: EEG dictionary, or list of EEG dictionaries. - thresholds: 7-by-2 threshold matrix in ICLabel class order. - gui: Force or suppress the GUI threshold dialog. - renderer: Optional GUI renderer used by tests. - return_com: Return ``(EEG, command)`` when true. - - Returns: - dict or tuple: Updated EEG, and optionally the history command. - """ - if EEG is None: - return (None, "") if return_com else None - if gui is None: - gui = thresholds is None - if gui: - if thresholds is None and not isinstance(EEG, list): - _require_iclabel(EEG) - eeg_icalabelstat(EEG) - result = _run_gui(renderer=renderer) - if result is None: - return (EEG, "") if return_com else EEG - thresholds = result["thresholds"] - if thresholds is None: - thresholds = DEFAULT_ICFLAG_THRESHOLDS - thresholds = _normalize_thresholds(thresholds) - - if isinstance(EEG, list): - output = [_flag_dataset(item, thresholds) for item in EEG] - command = _history_command(thresholds) - return (output, command) if return_com else output - - output = _flag_dataset(EEG, thresholds) - command = _history_command(thresholds) - return (output, command) if return_com else output - - -def _flag_dataset(EEG: dict, thresholds: np.ndarray) -> dict: - _require_iclabel(EEG) - return eeg_icflag(copy.deepcopy(EEG), thresholds) - - -def pop_icflag_dialog_spec() -> DialogSpec: - """Return the EEGLAB-like dialog spec for ``pop_icflag``.""" - controls: list[ControlSpec] = [ - ControlSpec("text", "Select range for flagging component for rejection", font_weight="bold"), - ControlSpec("spacer"), - ControlSpec("text", "Min"), - ControlSpec("text", "Max"), - ] - geometry: list[tuple[float, ...] | tuple[int, ...]] = [(1,), (2.0, 0.4, 0.4)] - for index, label in enumerate(ICLABEL_CLASSES): - controls.extend( - [ - ControlSpec("text", f'Probability range for "{label}"'), - ControlSpec("edit", tag=f"min_{index}", value=_threshold_text(DEFAULT_ICFLAG_THRESHOLDS[index, 0])), - ControlSpec("edit", tag=f"max_{index}", value=_threshold_text(DEFAULT_ICFLAG_THRESHOLDS[index, 1])), - ] - ) - geometry.append((2.0, 0.4, 0.4)) - - return DialogSpec( - title="Flag components using ICLabel -- pop_icflag()", - function_name="pop_icflag", - eeglab_source="plugins/ICLabel/pop_icflag.m", - geometry=tuple(geometry), - size=(467, 440), - content_margins=(23, 37, 23, 13), - row_spacing=16, - controls=tuple(controls), - ) - - -def _run_gui(renderer: Any | None = None) -> dict[str, np.ndarray] | None: - result = inputgui(pop_icflag_dialog_spec(), renderer=renderer) - if result is None: - return None - thresholds = [] - for index in range(len(ICLABEL_CLASSES)): - thresholds.append( - [_parse_threshold(result.get(f"min_{index}", "")), _parse_threshold(result.get(f"max_{index}", ""))] - ) - return {"thresholds": _normalize_thresholds(thresholds)} - - -def _parse_threshold(value: Any) -> float: - text = str(value).strip() - if not text: - return float("nan") - return float(text) - - -def _normalize_thresholds(thresholds: Any) -> np.ndarray: - normalized = np.asarray(thresholds, dtype=float) - if normalized.shape != (7, 2): - raise ValueError("thresholds must be a 7x2 array") - finite = normalized[np.isfinite(normalized)] - if finite.size and ((finite < 0).any() or (finite > 1).any()): - raise ValueError("ICLabel thresholds must be between 0 and 1") - for row in normalized: - if np.isfinite(row).all() and row[0] > row[1]: - raise ValueError("ICLabel threshold minimum cannot exceed maximum") - return normalized - - -def _require_iclabel(EEG: dict) -> None: - try: - classifications = EEG["etc"]["ic_classification"]["ICLabel"]["classifications"] - except KeyError as exc: - raise ValueError("No ICLabel classifications found. Run pop_iclabel first.") from exc - if np.asarray(classifications).size == 0: - raise ValueError("No ICLabel classifications found. Run pop_iclabel first.") - - -def _threshold_text(value: float) -> str: - return "" if np.isnan(value) else f"{value:g}" - - -def _history_command(thresholds: np.ndarray) -> str: - return f"EEG = pop_icflag(EEG, thresholds={_history_thresholds(thresholds)});" - - -def _history_thresholds(thresholds: np.ndarray) -> str: - values = [ - [None if np.isnan(value) else float(value) for value in row] - for row in np.asarray(thresholds, dtype=float).tolist() - ] - return format_history_value(values, cell_for_sequence=None) diff --git a/src/eegprep/plugins/ICLabel/pop_iclabel.py b/src/eegprep/plugins/ICLabel/pop_iclabel.py deleted file mode 100644 index 092286aa..00000000 --- a/src/eegprep/plugins/ICLabel/pop_iclabel.py +++ /dev/null @@ -1,79 +0,0 @@ -"""EEGLAB-style pop wrapper for ICLabel.""" - -from __future__ import annotations - -import numpy as np - -from eegprep.functions.guifunc.inputgui import inputgui -from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.plugins.ICLabel.iclabel import iclabel - - -_VERSIONS = ("default", "lite", "beta") - - -def pop_iclabel( - EEG, - icversion: str | None = None, - *, - gui: bool | None = None, - renderer=None, - engine=None, - return_com: bool = False, -): - """Classify independent components using ICLabel.""" - if EEG is None: - return (None, "") if return_com else None - if gui is None: - gui = icversion is None - if gui: - result = _run_gui(renderer=renderer) - if result is None: - return (EEG, "") if return_com else EEG - icversion = result["icversion"] - icversion = "default" if icversion is None else str(icversion).lower() - if icversion not in _VERSIONS: - raise ValueError("icversion must be one of 'default', 'lite', or 'beta'") - if isinstance(EEG, list): - output = [pop_iclabel(item, icversion, gui=False, engine=engine) for item in EEG] - command = _history_command(icversion) - return (output, command) if return_com else output - _require_ica(EEG) - output = iclabel(EEG, algorithm=icversion, engine=engine) - command = _history_command(icversion) - return (output, command) if return_com else output - - -def pop_iclabel_dialog_spec() -> DialogSpec: - """Return the EEGLAB-like dialog spec for ``pop_iclabel``.""" - return DialogSpec( - title="ICLabel", - function_name="pop_iclabel", - eeglab_source="plugins/ICLabel/pop_iclabel.m", - geometry=((1,), (1,)), - size=(356, 199), - show_help_button=False, - controls=( - ControlSpec("text", "Select which icversion of ICLabel to use:"), - ControlSpec("popupmenu", "Default (recommended)|Lite|Beta", tag="icversion", value=1), - ), - ) - - -def _run_gui(renderer=None): - result = inputgui(pop_iclabel_dialog_spec(), renderer=renderer) - if result is None: - return None - index = int(result.get("icversion", 1)) - 1 - index = max(0, min(index, len(_VERSIONS) - 1)) - return {"icversion": _VERSIONS[index]} - - -def _require_ica(EEG): - weights = EEG.get("icaweights") - if weights is None or np.asarray(weights).size == 0: - raise ValueError("ICLabel requires an ICA decomposition. Run pop_runica first.") - - -def _history_command(icversion): - return f"EEG = pop_iclabel(EEG, '{icversion}');" diff --git a/src/eegprep/plugins/ICLabel/pop_prop_extended.py b/src/eegprep/plugins/ICLabel/pop_prop_extended.py deleted file mode 100644 index 42a16a2a..00000000 --- a/src/eegprep/plugins/ICLabel/pop_prop_extended.py +++ /dev/null @@ -1,189 +0,0 @@ -"""ICLabel extended channel/component property dashboard.""" - -from __future__ import annotations - -from typing import Any - -from eegprep.functions.guifunc.inputgui import inputgui -from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.functions.popfunc._plot_utils import history_command -from eegprep.plugins.ICLabel._prop_browser import build_navigable_dashboard -from eegprep.plugins.ICLabel._prop_numerics import ( - DEFAULT_ICLABEL_CLASSES, - ClassifierData, - DipfitData, - ExtendedPropertyData, - build_extended_property_data, - classifier_default_index, - classifier_name_from_gui, - classifier_names, - component_count, - component_rejection_status, - has_component_classifier, - resolve_classifier_data, - resolve_dipfit_data, - selected_property_indices, -) - - -def pop_prop_extended( - EEG: dict[str, Any] | None = None, - typecomp: int | bool = 1, - chanorcomp: Any = None, - winhandle: Any = None, - spec_opt: Any = None, - erp_opt: Any = None, - scroll_event: int | bool = 1, - classifier_name: str = "", - fig: Any = None, - *, - gui: bool | None = None, - renderer: Any | None = None, - plot: bool = True, - show_activity: bool = False, - reject_callback: Any | None = None, - return_com: bool = False, -): - """Render the EEGLAB viewprops-style extended property dashboard. - - Args: - EEG: EEGLAB-style EEG dictionary. - typecomp: ``1`` for channels, ``0`` for ICA components. - chanorcomp: EEGLAB-facing 1-based channel/component index or indices. - winhandle: Accepted for EEGLAB signature compatibility. - spec_opt: EEGLAB-style ``spectopo`` option text or parsed options. - erp_opt: Accepted for EEGLAB signature compatibility; image options are - currently interpreted by EEGPrep's native image summary. - scroll_event: Include events in the attached activity browser. - classifier_name: Component classifier field in ``EEG.etc.ic_classification``. - fig: Optional Matplotlib figure to reuse. - gui: Force or suppress the EEGLAB-like input dialog. - renderer: Optional dialog renderer for tests. - plot: Build the Matplotlib dashboard when true. - show_activity: Open the browser-backed activity window in addition to - attaching its model/window to the dashboard figure. - reject_callback: Optional callable invoked after OK commits component - rejection states. It receives ``(EEG, states)`` where ``states`` maps - EEGLAB-facing component indices to booleans. - return_com: Return ``(figure, command)`` when true. - """ - if EEG is None: - return (None, "") if return_com else None - typecomp = int(bool(typecomp)) - if gui is None: - gui = chanorcomp is None - if gui: - result = inputgui(pop_prop_extended_dialog_spec(EEG, typecomp), renderer=renderer) - if result is None: - return (None, "") if return_com else None - chanorcomp = result.get("chanorcomp", "1") - spec_opt = result.get("spec_opt", "") - erp_opt = result.get("erp_opt", "") - scroll_event = int(bool(result.get("scroll_event", True))) - if not typecomp: - classifier_name = classifier_name_from_gui(EEG, result.get("classifier_name", classifier_name)) - if chanorcomp is None: - chanorcomp = 1 - indices = selected_property_indices(EEG, typecomp, chanorcomp, default_all=False) - command = _history_command(typecomp, indices, winhandle, spec_opt, erp_opt, scroll_event, classifier_name) - figure = None - if plot: - figure = build_navigable_dashboard( - EEG, - typecomp, - indices, - winhandle, - spec_opt, - erp_opt, - scroll_event, - classifier_name, - fig=fig, - show_activity=show_activity, - reject_callback=reject_callback, - ) - return (figure, command) if return_com else figure - - -def pop_prop_extended_dialog_spec(EEG: dict[str, Any], typecomp: int | bool = 1) -> DialogSpec: - """Return the EEGLAB-like ``pop_prop_extended`` prompt.""" - is_channel = int(bool(typecomp)) - limit = int(EEG.get("nbchan", 0) or 0) if is_channel else component_count(EEG) - label = "Channel index(ices) to plot:" if is_channel else "Component index(ices) to plot:" - controls = [ - ControlSpec("text", label), - ControlSpec("edit", tag="chanorcomp", value="1" if limit else ""), - ControlSpec("text", "Spectral options (see spectopo() help):"), - ControlSpec("edit", tag="spec_opt", value=f"'freqrange', [2 {min(80, float(EEG.get('srate', 1)) / 2):g}]"), - ControlSpec("text", "Erpimage options (see erpimage() help):"), - ControlSpec("edit", tag="erp_opt", value=""), - ControlSpec( - "checkbox", - f"Draw events over scrolling {'channel' if is_channel else 'component'} activity", - tag="scroll_event", - value=True, - ), - ] - classifiers = classifier_names(EEG) if not is_channel else [] - if classifiers: - controls.append( - ControlSpec( - "popupmenu", - "|".join(classifiers), - tag="classifier_name", - value=classifier_default_index(classifiers), - ) - ) - return DialogSpec( - title=f"{'Channel' if is_channel else 'Component'} properties - pop_prop_extended()", - function_name="pop_prop_extended", - eeglab_source="plugins/ICLabel/viewprops/pop_prop_extended.m", - help_text="pophelp('pop_prop_extended')", - size=(600, 318 if classifiers else 288), - geometry=((1.3, 1), (1.3, 1), (1.3, 1), (1,), *((1,) if classifiers else ())), - controls=tuple(controls), - known_differences=( - "EEGPrep uses one navigable dashboard for multiple selected components instead of opening one " - "separate figure per component.", - ), - ) - - -def _history_command( - typecomp: int, - indices: list[int], - winhandle: Any, - spec_opt: Any, - erp_opt: Any, - scroll_event: int | bool, - classifier_name: str, -) -> str: - return history_command( - "pop_prop_extended", - int(bool(typecomp)), - indices if len(indices) != 1 else indices[0], - winhandle, - [] if spec_opt is None else spec_opt, - [] if erp_opt is None else erp_opt, - int(bool(scroll_event)), - classifier_name, - ) - - -__all__ = [ - "DEFAULT_ICLABEL_CLASSES", - "ClassifierData", - "DipfitData", - "ExtendedPropertyData", - "build_extended_property_data", - "classifier_default_index", - "classifier_name_from_gui", - "classifier_names", - "component_count", - "component_rejection_status", - "has_component_classifier", - "pop_prop_extended", - "pop_prop_extended_dialog_spec", - "resolve_classifier_data", - "resolve_dipfit_data", - "selected_property_indices", -] diff --git a/src/eegprep/plugins/ICLabel/pop_viewprops.py b/src/eegprep/plugins/ICLabel/pop_viewprops.py deleted file mode 100644 index 1c150089..00000000 --- a/src/eegprep/plugins/ICLabel/pop_viewprops.py +++ /dev/null @@ -1,207 +0,0 @@ -"""View channel or component property thumbnails.""" - -from __future__ import annotations - -from typing import Any - -import matplotlib.pyplot as plt -import numpy as np - -from eegprep.functions.guifunc.inputgui import inputgui -from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.functions.popfunc._property_browser import property_activity_browser -from eegprep.functions.popfunc._pop_utils import format_history_value -from eegprep.functions.popfunc._rejection import one_based_indices -from eegprep.functions.popfunc.pop_topoplot import pop_topoplot -from eegprep.plugins.ICLabel.pop_prop_extended import ( - classifier_default_index, - classifier_name_from_gui, - classifier_names, - component_count, - has_component_classifier, - pop_prop_extended, -) - - -PLOTS_PER_FIGURE = 35 - - -def pop_viewprops( - EEG: dict[str, Any], - typecomp: int | bool = 1, - chanorcomp: Any = None, - spec_opt: Any = None, - erp_opt: Any = None, - scroll_event: int | bool = 1, - classifier_name: str = "", - fig: Any = None, - *, - gui: bool | None = None, - renderer: Any | None = None, - plot: bool = True, - show_activity: bool = False, - reject_callback: Any | None = None, - return_com: bool = False, -): - """Render channel/component property overview figures and activity views.""" - if EEG is None: - return ([], "") if return_com else [] - if gui is None: - gui = chanorcomp is None - if gui: - result = inputgui(pop_viewprops_dialog_spec(EEG, typecomp), renderer=renderer) - if result is None: - return ([], "") if return_com else [] - chanorcomp = result.get("chanorcomp", "") - spec_opt = result.get("spec_opt", "") - erp_opt = result.get("erp_opt", "") - scroll_event = int(bool(result.get("scroll_event", True))) - if not int(bool(typecomp)): - classifier_name = classifier_name_from_gui(EEG, result.get("classifier_name", classifier_name)) - limit = int(EEG.get("nbchan", 0) or 0) if int(bool(typecomp)) else component_count(EEG) - indices = one_based_indices(chanorcomp, limit=limit, default_all=True) - figures = [] - if plot: - figures = _plot_props( - EEG, - int(bool(typecomp)), - indices, - spec_opt, - erp_opt, - classifier_name, - scroll_event, - show_activity, - reject_callback, - ) - command = _history_command(typecomp, indices, spec_opt, erp_opt, scroll_event, classifier_name) - return (figures, command) if return_com else figures - - -def pop_viewprops_dialog_spec(EEG: dict[str, Any], typecomp: int | bool = 1) -> DialogSpec: - """Return the EEGLAB-like ``pop_viewprops`` prompt.""" - is_channel = int(bool(typecomp)) - limit = int(EEG.get("nbchan", 0) or 0) if is_channel else component_count(EEG) - label = "Channel indices to plot:" if is_channel else "Component indices to plot:" - title = "View many chan or comp. properties -- pop_viewprops" - geometry = [(1.3, 1), (1.3, 1), (1.3, 1), (1,)] - controls = [ - ControlSpec("text", label), - ControlSpec("edit", tag="chanorcomp", value=f"1:{limit}"), - ControlSpec("text", "Spectral options (see spectopo() help):"), - ControlSpec("edit", tag="spec_opt", value=f"'freqrange', [2 {min(80, float(EEG.get('srate', 1)) / 2):g}]"), - ControlSpec("text", "Erpimage options (see erpimage() help):"), - ControlSpec("edit", tag="erp_opt", value=""), - ControlSpec( - "checkbox", - f"Draw events over scrolling {'channel' if is_channel else 'component'} activity", - tag="scroll_event", - value=True, - ), - ] - classifiers = classifier_names(EEG) if not is_channel else [] - if classifiers: - geometry.append((1,)) - controls.append( - ControlSpec( - "popupmenu", - "|".join(classifiers), - tag="classifier_name", - value=classifier_default_index(classifiers), - ) - ) - return DialogSpec( - title=title, - function_name="pop_viewprops", - eeglab_source="plugins/ICLabel/viewprops/pop_viewprops.m", - size=(600, 318 if classifiers else 288), - geometry=tuple(geometry), - controls=tuple(controls), - ) - - -def _plot_props( - EEG: dict[str, Any], - typecomp: int, - indices: list[int], - spec_opt: Any, - erp_opt: Any, - classifier_name: str, - scroll_event: int | bool, - show_activity: bool, - reject_callback: Any | None, -) -> list[Any]: - if not indices: - return [] - visible_indices = indices[:PLOTS_PER_FIGURE] - if not typecomp and has_component_classifier(EEG, classifier_name): - figure = pop_prop_extended( - EEG, - 0, - visible_indices, - None, - spec_opt, - erp_opt, - scroll_event, - classifier_name, - gui=False, - show_activity=show_activity, - reject_callback=reject_callback, - ) - return [figure] if figure is not None else [] - activity_views = [ - property_activity_browser(EEG, typecomp, index, scroll_event=scroll_event, show=show_activity) - for index in visible_indices - ] - if not typecomp: - figures = pop_topoplot( - EEG, 0, visible_indices, "View components properties - pop_viewprops()", [], 0, gui=False - ) - _attach_activity_views(figures, activity_views) - return figures - fig_obj, axes = plt.subplots(1, min(len(indices), PLOTS_PER_FIGURE), squeeze=False) - labels = _channel_labels(EEG) - for axis, index in zip(axes.ravel(), visible_indices): - axis.axis("off") - label = labels[index - 1] if index - 1 < len(labels) else str(index) - axis.text(0.5, 0.5, label, ha="center", va="center", fontsize=10) - axis.set_title(str(index)) - fig_obj.suptitle("View channels properties - pop_viewprops()") - fig_obj.tight_layout() - _attach_activity_views([fig_obj], activity_views) - return [fig_obj] - - -def _attach_activity_views(figures: list[Any], activity_views: list[Any]) -> None: - for figure in figures: - figure.eegprep_activity_views = activity_views - - -def _history_command( - typecomp: int | bool, - indices: list[int], - spec_opt: Any, - erp_opt: Any, - scroll_event: int | bool, - classifier_name: str, -) -> str: - args = [ - int(bool(typecomp)), - indices, - [] if spec_opt is None else spec_opt, - [] if erp_opt is None else erp_opt, - int(bool(scroll_event)), - classifier_name, - ] - return "pop_viewprops(EEG, " + ", ".join(format_history_value(arg, cell_for_sequence=None) for arg in args) + ");" - - -def _channel_labels(EEG: dict[str, Any]) -> list[str]: - labels = [] - chanlocs = EEG.get("chanlocs", []) - if chanlocs is None: - chanlocs = [] - if isinstance(chanlocs, np.ndarray): - chanlocs = chanlocs.tolist() - for index, chanloc in enumerate(chanlocs): - labels.append(str(chanloc.get("labels", index + 1)) if isinstance(chanloc, dict) else str(index + 1)) - return labels or [str(index + 1) for index in range(int(EEG.get("nbchan", 0) or 0))] diff --git a/src/eegprep/plugins/clean_rawdata/__init__.py b/src/eegprep/plugins/clean_rawdata/__init__.py deleted file mode 100644 index 997922e0..00000000 --- a/src/eegprep/plugins/clean_rawdata/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -"""clean_rawdata plugin ports.""" - -from __future__ import annotations - -from importlib import import_module -from typing import Any - -_LAZY_EXPORTS = { - "asr_calibrate": ("eegprep.plugins.clean_rawdata.asr_calibrate", "asr_calibrate"), - "asr_process": ("eegprep.plugins.clean_rawdata.asr_process", "asr_process"), - "clean_artifacts": ("eegprep.plugins.clean_rawdata.clean_artifacts", "clean_artifacts"), - "clean_asr": ("eegprep.plugins.clean_rawdata.clean_asr", "clean_asr"), - "clean_channels": ("eegprep.plugins.clean_rawdata.clean_channels", "clean_channels"), - "clean_channels_nolocs": ("eegprep.plugins.clean_rawdata.clean_channels_nolocs", "clean_channels_nolocs"), - "clean_drifts": ("eegprep.plugins.clean_rawdata.clean_drifts", "clean_drifts"), - "clean_flatlines": ("eegprep.plugins.clean_rawdata.clean_flatlines", "clean_flatlines"), - "clean_windows": ("eegprep.plugins.clean_rawdata.clean_windows", "clean_windows"), - "pop_clean_rawdata": ("eegprep.plugins.clean_rawdata.pop_clean_rawdata", "pop_clean_rawdata"), - "vis_artifacts": ("eegprep.plugins.clean_rawdata.vis_artifacts", "vis_artifacts"), - "vis_artifacts_diagnostics": ( - "eegprep.plugins.clean_rawdata.vis_artifacts", - "vis_artifacts_diagnostics", - ), -} - -__all__ = sorted(_LAZY_EXPORTS) - - -def __getattr__(name: str) -> Any: - try: - module_name, attr_name = _LAZY_EXPORTS[name] - except KeyError as exc: - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc - value = getattr(import_module(module_name), attr_name) - globals()[name] = value - return value diff --git a/src/eegprep/plugins/clean_rawdata/asr_calibrate.py b/src/eegprep/plugins/clean_rawdata/asr_calibrate.py deleted file mode 100644 index 775aaddc..00000000 --- a/src/eegprep/plugins/clean_rawdata/asr_calibrate.py +++ /dev/null @@ -1,537 +0,0 @@ -"""Artifact Subspace Reconstruction (ASR) utilities.""" - -import logging -import math -import numpy as np -import scipy.signal -import scipy.linalg - -from ...functions.miscfunc.misc import canonicalize_signs, finite_matmul, round_mat -from .private.covariance import cov_mean, cov_shrinkage -from .private.stats import fit_eeg_distribution, geometric_median - -logger = logging.getLogger(__name__) - -# Sampling rates (Hz) for which a pre-computed spectral-shaping IIR filter is available. -_SUPPORTED_SRATES = frozenset({100, 128, 200, 250, 256, 300, 500, 512}) - - -def asr_calibrate( - X, - srate, - cutoff=None, - blocksize=None, - B=None, - A=None, - window_len=None, - window_overlap=None, - max_dropout_fraction=None, - min_clean_fraction=None, - maxmem=None, - useriemannian=None, - compatibility=None, -): - """Calibration function for the Artifact Subspace Reconstruction (ASR) method. - - State = asr_calibrate(Data, SamplingRate, Cutoff, BlockSize, FilterB, FilterA, WindowLength, WindowOverlap, MaxDropoutFraction, MinCleanFraction, MaxMemory) - - The input to this data is a multi-channel time series of calibration data. In typical uses the - calibration data is clean resting EEG data of ca. 1 minute duration (can also be longer). One can - also use on-task data if the fraction of artifact content is below the breakdown point of the - robust statistics used for estimation (50% theoretical, ~30% practical). If the data has a - proportion of more than 30-50% artifacts then bad time windows should be removed beforehand. This - data is used to estimate the thresholds that are used by the ASR processing function to identify - and remove artifact components. - - The calibration data must have been recorded for the same cap design from which data for cleanup - will be recorded, and ideally should be from the same session and same subject, but it is possible - to reuse the calibration data from a previous session and montage to the extent that the cap is - placed in the same location (where loss in accuracy is more or less proportional to the mismatch - in cap placement). - - The calibration data should have been high-pass filtered (for example at 0.5Hz or 1Hz using a - Butterworth IIR filter). - - Args: - X (np.ndarray): Calibration data [#channels x #samples]; *zero-mean* (e.g., high-pass filtered) and - reasonably clean EEG of not much less than 30 seconds length (this method is typically - used with 1 minute or more). - srate (float): Sampling rate of the data, in Hz. - - cutoff (float, optional): Standard deviation cutoff for rejection. Data portions whose variance is larger - than this threshold relative to the calibration data are considered missing - data and will be removed. The most aggressive value that can be used without - losing too much EEG is 5. Default: 5.0. - blocksize (int, optional): Block size for calculating the robust data covariance and thresholds, in samples; - allows to reduce the memory and time requirements of the robust estimators by this - factor. Default: 10. (Note: Memory-based dynamic calculation from MATLAB not implemented). - B (np.ndarray, optional): Numerator coefficients of an IIR filter used for shaping the spectrum for artifact statistics. - Default: Calculated using pre-computed values based on srate (approximating yulewalk). - A (np.ndarray, optional): Denominator coefficients of an IIR filter used for shaping the spectrum for artifact statistics. - Default: Calculated using pre-computed values based on srate (approximating yulewalk). - window_len (float, optional): Window length in seconds for checking artifact content. Default: 0.5. - window_overlap (float, optional): Window overlap fraction (0-1). Default: 0.66. - max_dropout_fraction (float, optional): Maximum fraction (0-1) of windows subject to dropouts. Default: 0.1. - min_clean_fraction (float, optional): Minimum fraction (0-1) of windows that must be clean. Default: 0.25. - maxmem (int, optional): Maximum memory in MB (for very large data/many channels). Default: 64. - useriemannian (str, optional): Option to use a Riemannian ASR variant. Can be set to 'calib' to use a Riemannian estimate - at calibration time; this make somewhat different statistical tradeoffs than the default, resulting in a potentially - different baseline rejection threshold; as a result it is suggested to visually check results and adjust - the cutoff as needed. Default: None (disabled). - compatibility (str, optional): MATLAB compatibility level. - * 'standard' (default) aims for 5 significant digits compatibility and may apply - slightly better numerical methods (e.g. using SOS filters for IIR filtering) - that are not available in stock MATLAB and therefore not used in the ASR - reference implementation. - * 'max' aims for maximum compatibility with MATLAB's results, aiming to match - results as closely as possible, perhaps trading off numerical robustness in - turn. Note the effects will mostly likely be miniscule and the MATLAB ASR - implementation is known to be highly robust. - - Returns - ------- - dict: State dictionary containing calibration results ('M', 'T') and filter parameters ('B', 'A', 'sos', 'iir_state') - needed for `asr_process`. - """ - # Ensure X is a numpy array and C x S - X = np.asarray(X, dtype=float) - if X.ndim != 2: - raise ValueError("Input data X must be a 2D array (channels x samples).") - C, S = X.shape - srate = float(srate) - - # Parameter defaults - if cutoff is None: - cutoff = 5.0 - if blocksize is None: - blocksize = 10 - if maxmem is None: - maxmem = 64 # in MB - if window_len is None: - window_len = 0.5 - if window_overlap is None: - window_overlap = 0.66 - if max_dropout_fraction is None: - max_dropout_fraction = 0.1 - if min_clean_fraction is None: - min_clean_fraction = 0.25 - if compatibility is None: - compatibility = 'standard' - - # there's no record of when or how this formula crept into the MATLAB code, but - # to match it, we'll have to use it here as well - blocksize = max(blocksize, math.ceil((C * C * S * 8 * 3 * 2) / (maxmem * (2**21)))) - - # Default IIR filter coefficients (approximating MATLAB's yulewalk defaults) - # Based on artifact_removal_legacy.py and asr_calibrate.m logic - if B is None or A is None: - sr_round = int(round_mat(srate)) - if sr_round == 100: - B = np.array( - [ - 0.9314233528641650, - -1.0023683814963549, - -0.4125359862018213, - 0.7631567476327510, - 0.4160430392910331, - -0.6549131038692215, - -0.0372583518046807, - 0.1916268458752655, - 0.0462411971592346, - ], - dtype=np.float64, - ) - A = np.array( - [ - 1.0000000000000000, - -0.4544220180303844, - -1.0007038682936749, - 0.5374925521337940, - 0.4905013360991340, - -0.4861062879351137, - -0.1995986490699414, - 0.1830048420730026, - 0.0457678549234644, - ], - dtype=np.float64, - ) - elif sr_round == 128: - B = np.array( - [ - 1.1027301639165037, - -2.0025621813611867, - 0.8942119516481342, - 0.1549979524226999, - 0.0192366904488084, - 0.1782897770278735, - -0.5280306696498717, - 0.2913540603407520, - -0.0262209802526358, - ], - dtype=np.float64, - ) - A = np.array( - [ - 1.0000000000000000, - -1.1042042046423233, - -0.3319558528606542, - 0.5802946221107337, - -0.0010360013915635, - 0.0382167091925086, - -0.2609928034425362, - 0.0298719057761086, - 0.0935044692959187, - ], - dtype=np.float64, - ) - elif sr_round == 200: - B = np.array( - [ - 1.4489483325802353, - -2.6692514764802775, - 2.0813970620731115, - -0.9736678877049534, - 0.1054605060352928, - -0.1889101692314626, - 0.6111331636592364, - -0.3616483013075088, - 0.1834313060776763, - ], - dtype=np.float64, - ) - A = np.array( - [ - 1.0000000000000000, - -0.9913236099393967, - 0.3159563145469344, - -0.0708347481677557, - -0.0558793822071149, - -0.2539619026478943, - 0.2473056615251193, - -0.0420478437473110, - 0.0077455718334464, - ], - dtype=np.float64, - ) - elif sr_round == 250: - B = np.array( - [ - 1.7313331085426, - -4.168133532957, - 5.37379900844173, - -5.57212564343886, - 4.70122651316513, - -3.34208799655246, - 1.95045488724908, - -0.766909658912065, - 0.233281060974837, - ], - dtype=np.float64, - ) - A = np.array( - [ - 1.0, - -1.6384949276666, - 1.73987814299055, - -1.83638657883456, - 1.3924177536798, - -0.953780426622198, - 0.505158779550745, - -0.159504514603055, - 0.0545278399847978, - ], - dtype=np.float64, - ) - elif sr_round == 256: - B = np.array( - [ - 1.7587013141770287, - -4.3267624394458641, - 5.7999880031015953, - -6.2396625463547508, - 5.3768079046882207, - -3.7938218893374835, - 2.1649108095226470, - -0.8591392569863763, - 0.2569361125627988, - ], - dtype=np.float64, - ) - A = np.array( - [ - 1.0000000000000000, - -1.7008039639301735, - 1.9232830391058724, - -2.0826929726929797, - 1.5982638742557307, - -1.0735854183930011, - 0.5679719225652651, - -0.1886181499768189, - 0.0572954115997261, - ], - dtype=np.float64, - ) - elif sr_round == 300: - B = np.array( - [ - 1.9153920676433143, - -5.7748421104926795, - 9.1864764859103936, - -10.7350356619363630, - 9.6423672437729007, - -6.6181939699544277, - 3.4219421494177711, - -1.2622976569994351, - 0.2968423019363821, - ], - dtype=np.float64, - ) - A = np.array( - [ - 1.0000000000000000, - -2.3143703322055491, - 3.2222567327379434, - -3.6030527704320621, - 2.9645154844073698, - -1.8842615840684735, - 0.9222455868758080, - -0.3103251703648485, - 0.0634586449896364, - ], - dtype=np.float64, - ) - elif sr_round == 500: - B = np.array( - [ - 2.3133520086975823, - -11.9471223009159130, - 29.1067166493384340, - -43.7550171007238190, - 44.3385767452216370, - -30.9965523846388000, - 14.6209883020737190, - -4.2743412400311449, - 0.5982553583777899, - ], - dtype=np.float64, - ) - A = np.array( - [ - 1.0000000000000000, - -4.6893329084452580, - 10.5989986701080210, - -14.9691518101365230, - 14.3320358399731820, - -9.4924317069169977, - 4.2425899618982656, - -1.1715600975178280, - 0.1538048427717476, - ], - dtype=np.float64, - ) - elif sr_round == 512: - B = np.array( - [ - 2.3275475636130865, - -12.2166478485960430, - 30.1632789058248850, - -45.8009842020820410, - 46.7261263011068880, - -32.7796858196767220, - 15.4623349612560630, - -4.5019779685307473, - 0.6242733481676324, - ], - dtype=np.float64, - ) - A = np.array( - [ - 1.0000000000000000, - -4.7827378944258703, - 10.9780696236622980, - -15.6795187888195360, - 15.1281978667576310, - -10.0632079834518220, - 4.5014690636505614, - -1.2394100873286753, - 0.1614727510688058, - ], - dtype=np.float64, - ) - else: - # No precomputed spectral-shaping filter for this sampling rate. Degrading - # to a trivial difference filter would silently miscalibrate the ASR - # thresholds, so fail loudly instead (mirrors MATLAB's asr_calibrate:NoYulewalk). - raise ValueError( - f"No pre-computed ASR spectral filter for srate {srate} Hz " - f"(supported: {sorted(_SUPPORTED_SRATES)}). Resample the data to a " - "supported rate or pass explicit filter coefficients via B and A." - ) - - # Ensure data is finite - X[~np.isfinite(X)] = 0.0 - - # Apply the signal shaping filter based on compatibility mode - if compatibility == 'max': - # Maximum MATLAB compatibility: use B/A form with lfilter - # Initialize filter state to zeros (matching MATLAB's filter(..., [], 2)) - # For multi-channel data (C x S) filtering along axis=1, zi shape is (C, max(len(A), len(B)) - 1) - zi = np.zeros((C, max(len(A), len(B)) - 1)) - Xf, iir_state = scipy.signal.lfilter(B, A, X, axis=1, zi=zi) - sos = None # Not used in this mode - else: - # Standard mode: use second-order sections (SOS) for numerical stability - sos = scipy.signal.tf2sos(B, A) - # Need initial state per channel: shape (n_sections, n_channels, 2) - # (since the data are assumed to be zero-mean, use a zero state, as in MATLAB) - zi = np.zeros((sos.shape[0], C, 2)) - Xf, iir_state = scipy.signal.sosfilt(sos, X, axis=1, zi=zi) - - if np.any(~np.isfinite(Xf)): - raise RuntimeError( - 'The IIR filter diverged on your data. Please try using either ' - 'a more conservative filter or removing some bad sections/channels from the calibration data.' - ) - - # Calculate the sample covariance matrices U (averaged in blocks of blocksize successive samples) - # U will be shape (C, C, num_blocks) - logger.info("Calculating blockwise covariances...") - - # Determine the number of blocks - num_blocks = int(np.ceil(S / blocksize)) - U = np.zeros((C, C, num_blocks)) - block_starts = np.arange(0, S, blocksize) - - # Accumulate outer products in blocks for memory efficiency - for k in range(blocksize): - # Calculate indices for this step, avoiding going past the end - range_indices = np.minimum(block_starts + k, S - 1) - if range_indices.size == 0: - continue # Skip if no indices - - # Extract data for these indices - X_k = Xf[:, range_indices] - - # Calculate and accumulate outer products - outer_products = np.reshape(X_k, (C, 1, -1)) * np.reshape(X_k, (1, C, -1)) - - # Add to U, ensuring shape alignment - if outer_products.shape[2] < U.shape[2]: - U[:, :, : outer_products.shape[2]] += outer_products - else: - U += outer_products - - # Average the accumulated covariances - U /= blocksize - - # compute a robust average of the covariance matrices - med = None - if useriemannian in ('calib', 'all', True): - logger.info("Calculating Riemannian geometric median covariance...") - U = U.transpose(2, 0, 1) - # small amount of shrinkage to prevent singularities - U = cov_shrinkage(U, 1e-4, target='scaled-eye') - med = cov_mean(U, robust=True) - if med is None or np.any(np.isnan(med)): - if med is not None: - logger.warning( - "Riemannian geometric median calculation resulted in NaNs. Using standard geometric median as fallback." - ) - logger.info("Calculating robust geometric median covariance...") - med = geometric_median(U.reshape(C * C, -1).T) - if np.any(np.isnan(med)): - logger.warning("Geometric median calculation resulted in NaNs. Using standard median as fallback.") - med = np.median(U, axis=-1) - - # make sure median is reshaped back to matrix form - M_robust = np.reshape(med, (C, C)) - - # Get the mixing matrix M (matrix square root of the robust covariance) - M = scipy.linalg.sqrtm(np.real(M_robust)) - M = np.real(M) # Ensure M is real - - # ----- Calculate Thresholds ----- - # Window length for calculating thresholds - N = int(round_mat(window_len * srate)) - if S < N: - raise ValueError(f'Not enough calibration data. Need at least {N} samples, got {S}.') - - logger.info('Determining per-component thresholds...') - - # Eigendecomposition of M plus some massaging - # to ensure reproducibility across platforms - M = 0.5 * (M + M.T) # Ensure symmetry - D, V = np.linalg.eigh(M) # eigh returns sorted eigenvalues - V = canonicalize_signs(V) - - # Transform data into component space (using eigenvectors) - X_transformed = np.abs(finite_matmul(Xf.T, V)) # Shape: (S, C) - - # Calculate window indices for RMS calculation - step = N * (1.0 - window_overlap) - if step <= 0: - logger.warning("Window overlap >= 1, using step=1") - step = 1 - window_starts = round_mat(np.arange(0, S - N, step)).astype(int) - - if len(window_starts) <= 1: - raise ValueError(f'Not enough windows possible. Need length > {N}, got {S}.') - - # Create window indices matrix - window_indices = window_starts[:, None] + np.arange(N) - - # Initialize arrays for mu and sigma - mu = np.zeros(C) - sig = np.zeros(C) - - # Calculate thresholds for each component - for c in reversed(range(C)): - comp_data = X_transformed[:, c] ** 2 - - # Calculate RMS amplitude for each window - rms_windows = np.sqrt(np.mean(comp_data[window_indices], axis=1)) - - # Fit a distribution to the clean part - try: - mu_c, sig_c, _, _ = fit_eeg_distribution( - rms_windows, min_clean_fraction=min_clean_fraction, max_dropout_fraction=max_dropout_fraction - ) - mu[c] = mu_c - sig[c] = sig_c - except Exception as e: - logger.warning(f"Distribution fitting failed for component {c}: {e}") - mu[c] = np.nan - sig[c] = np.nan - - # Check for NaN values and provide warning - if np.any(np.isnan(mu)) or np.any(np.isnan(sig)): - logger.warning("NaN values in threshold calculation. Results may be unreliable.") - # Replace NaNs with reasonable values - mu = np.nan_to_num(mu, nan=np.nanmedian(mu) if np.any(~np.isnan(mu)) else 1.0) - sig = np.nan_to_num(sig, nan=np.nanmedian(sig) if np.any(~np.isnan(sig)) else 0.5) - - # Ensure sigma is non-negative - sig = np.maximum(sig, 0) - - # Calculate threshold matrix T - T = finite_matmul(np.diag(mu + cutoff * sig), V.T) - - logger.info('Thresholds calculation complete.') - - # Return the state dictionary - state = { - 'M': M, # Mixing matrix - 'T': T, # Threshold matrix - 'B': B, # Original filter coefficients (for reference) - 'A': A, - 'sos': sos, # SOS filter representation for processing (None if compatibility='max') - 'iir_state': iir_state, # Initial filter state - 'cov': None, # Initial covariance buffer (will be set in process) - 'carry': None, # Initial carry buffer (will be set in process) - 'last_R': None, # Initial reconstruction matrix (will be set in process) - 'last_trivial': True, # Initial trivial flag - 'useriemannian': useriemannian, # Riemannian ASR variant option - 'compatibility': compatibility, # Compatibility mode for IIR filtering - } - - return state diff --git a/src/eegprep/plugins/clean_rawdata/asr_process.py b/src/eegprep/plugins/clean_rawdata/asr_process.py deleted file mode 100644 index c77b0134..00000000 --- a/src/eegprep/plugins/clean_rawdata/asr_process.py +++ /dev/null @@ -1,272 +0,0 @@ -"""Artifact Subspace Reconstruction (ASR) utilities.""" - -import logging - -import numpy as np -import scipy.signal - -from ...functions.miscfunc.misc import finite_matmul, finite_pinv, round_mat -from .private.sigproc import moving_average - -logger = logging.getLogger(__name__) - - -def asr_process( - data, srate, state, window_len=0.5, lookahead=None, step_size=32, max_dims=0.66, max_mem=None, use_gpu=False -): - """Process data using the Artifact Subspace Reconstruction (ASR) method. - - CleanedData, State = asr_process(Data, SamplingRate, State, WindowLength, LookAhead, StepSize, MaxDimensions, MaxMemory, UseGPU) - - This function is used to clean multi-channel signal using the ASR method. The required inputs are - the data matrix, the sampling rate of the data, and the filter state (as initialized by - asr_calibrate or from the previous call to asr_process). - - Args: - data (np.ndarray): Chunk of data to process [#channels x #samples]. Assumed to be - a continuation of previous data if 'state' is provided. - Data should be zero-mean (e.g., high-pass filtered). - srate (float): Sampling rate of the data in Hz. - state (dict): State dictionary from asr_calibrate or previous asr_process call. - Contains M, T, sos, iir_state, cov, carry, last_R, last_trivial. - window_len (float, optional): Length of the statistics window in seconds. Should not be much - longer than artifact time scale. Min samples: 1.5x channels. Default: 0.5. - lookahead (float, optional): Look-ahead amount in seconds (causes delay). Recommended: window_len/2. - Range [0, window_len/2]. Default: window_len/2. - step_size (int, optional): Update statistics every this many samples. Larger is faster. - Max: window_len * srate. Default: 32. - max_dims (float or int, optional): Maximum dimensions/fraction of dimensions to remove. - Default: 0.66 (fraction). - max_mem (int, optional): Maximum memory in MB for processing large chunks. Process in one block if None. - Default: None. - use_gpu (bool, optional): Whether to use GPU (not implemented). Default: False. - - Returns - ------- - tuple: (outdata, outstate) - outdata (np.ndarray): Cleaned data chunk (delayed by lookahead). - outstate (dict): Updated state dictionary for subsequent calls. - """ - # Check and sanitize data - data = np.asarray(data, dtype=float) - if data.ndim != 2: - raise ValueError("Input data must be a 2D array (channels x samples).") - C, S = data.shape - - if S == 0: - return data, state # Return empty data as is - - # Parameter handling - if lookahead is None: - lookahead = window_len / 2 - if max_mem is None: - # use at most half of available memory - import psutil - - max_mem = psutil.virtual_memory().free / 1024**2 / 2 - - # Ensure window length is adequate - window_len = max(window_len, 1.5 * C / srate) - - # Convert max_dims to actual number if given as fraction - if max_dims < 1: - max_dims_num = int(round_mat(C * max_dims)) - else: - max_dims_num = int(max_dims) - - # Number of samples in sliding window and lookahead - N = int(round_mat(window_len * srate)) - P = int(round_mat(lookahead * srate)) - - # Fix NaN and Inf values - data[~np.isfinite(data)] = 0 - - # Extract state variables - M = state['M'] # Mixing matrix - T = state['T'] # Threshold matrix - sos = state.get('sos') # SOS filter representation (None if compatibility='max') - b = state.get('B') # Filter numerator coefficients - a = state.get('A') # Filter denominator coefficients - compatibility = state.get('compatibility', 'standard') # Compatibility mode - iir_state = state.get('iir_state') # Filter state - carry = state.get('carry') # Carry buffer (previous lookahead data) - cov = state.get('cov') # Covariance state (MovAvgState or None) - # If cov is from an older run and is not a MovAvgState, reset it - if cov is not None and not hasattr(cov, 'buf'): - cov = None - last_R = state.get('last_R') # Last reconstruction matrix - last_trivial = state.get('last_trivial', True) # Was last step trivial (no artifacts) - - # Initialize prior filter state by extrapolating available data into the past - if carry is None: - ind = np.mod(np.arange(P + 1, 1, -1) - 1, S) - carry = 2 * data[:, [0]] - data[:, ind] - - # Prepend the carry buffer to the data - X = np.concatenate((carry, data), axis=1) - - # Calculate number of splits for memory management - - if max_mem * 1024 * 1024 - C * C * P * 8 * 3 < 0: - logger.warning( - "Memory too low, increasing it (rejection block size now " - "depends on available memory so it might not be fully reproducible)..." - ) - import psutil - - max_mem = psutil.virtual_memory().free / 1024**2 / 2 - if max_mem * 1024 * 1024 - C * C * P * 8 * 3 < 0: - raise RuntimeError('Not enough memory') - - # Calculate memory bytes needed (following reference implementation formula) - bytes_needed = C * C * S * 8 * 8 + C * C * 8 * S / step_size + C * S * 8 * 2 + S * 8 * 5 - - # Available memory in bytes (subtract fixed overhead) - mem_available = max_mem * 1024**2 - C * C * P * 8 * 3 - mem_available = max(mem_available, 1) # Ensure positive - - # Number of splits needed - splits = int(np.ceil(bytes_needed / mem_available)) - # Cap at reasonable value - splits = min(splits, 10000) - - if splits > 1: - logger.info(f'Cleaning data in {splits} blocks') - - # Process data in chunks - for k in range(splits): - # Calculate range for this chunk in the original data space - chunk_start = int(np.floor(k * S / splits)) - chunk_end = int(min(S, np.floor((k + 1) * S / splits))) - range_ = np.arange(chunk_start, chunk_end) - - if len(range_) == 0: - continue - - # Get spectrally shaped data for statistics computation (range shifted by lookahead) - Xraw = X[:, range_ + P] - - # Filter the data window based on compatibility mode - if compatibility == 'max': - # Maximum MATLAB compatibility: use B/A form with lfilter - Xfilt, iir_state = scipy.signal.lfilter(b, a, Xraw, axis=1, zi=iir_state) - else: - # Standard mode: use SOS form - Xfilt, iir_state = scipy.signal.sosfilt(sos, Xraw, axis=1, zi=iir_state) - - # Calculate per‑sample covariance vectors and compute the running mean - # covariance using the stateful `moving_average` implementation that - # replicates MATLAB's `moving_average` helper. This yields a smoothed - # covariance estimate for *every* sample and updates the internal - # circular buffer/state stored in `cov`. - Xcov_sample = np.reshape(np.reshape(Xfilt, (C, 1, -1)) * np.reshape(Xfilt, (1, C, -1)), (C * C, -1)) - - # Running mean over a window of N samples (along the last / time axis) - Xcov_filtered, cov = moving_average(Xcov_sample, N=N, axis=1, Z=cov, init=0) - - # Determine points at which to update the reconstruction matrix - update_at = np.arange(step_size, Xfilt.shape[1] + step_size - 1, step_size, dtype=int) - update_at = np.minimum(update_at, Xfilt.shape[1]) - - # If there is no previous R, initialize at first sample - if last_R is None: - update_at = np.insert(update_at, 0, 1) - last_R = np.eye(C) - - update_at -= 1 # prepare for 0-based indexing - - # Extract the covariance matrices at our update points (already - # averaged by the moving window) and reshape to C × C × #updates. - Xcov_matrices = np.reshape(Xcov_filtered[:, update_at], (C, C, len(update_at))) - - # Process each update point - last_n = -1 # MATLAB uses 1‑based indexing; align so first sample is included - for j, n in enumerate(update_at): - # Eigendecomposition to find potential artifact components - try: - D, V = np.linalg.eigh(Xcov_matrices[:, :, j]) - # Sort in ascending order (eigh already does this) - # D and V are already sorted in ascending order by eigh - except np.linalg.LinAlgError: - # Fallback if eigendecomposition fails - logger.warning(f"Eigendecomposition failed at update point {j}. Using identity matrix.") - D, V = np.ones(C), np.eye(C) - - # Determine which components to keep (variance below threshold or not admissible for rejection). - # No catch-all here: a shape/contract error in T/V must surface rather than silently - # disable artifact removal for this window. Genuine numerical-singularity cases are - # handled by the LinAlgError fallbacks for eigendecomposition and the pseudo-inverse. - thresholds = np.sum(finite_matmul(T, V) ** 2, axis=0) - keep = (D < thresholds) | (np.arange(1, C + 1) < (C - max_dims_num)) - trivial = np.all(keep) - - # Update the reconstruction matrix R - if not trivial: - try: - # Following reference implementation: - # Get V[:, keep] equivalent by multiplying V by a diagonal selection matrix - keep_mask = keep[np.newaxis, :] # Make column vector - A = finite_matmul(V.T, M) # V.T × M - masked_A_T = keep_mask * A.T # Zero out rows where keep is False - Q = masked_A_T.T # Back to original orientation - - # Calculate reconstruction matrix - Z = finite_pinv(Q) - R = np.real(finite_matmul(finite_matmul(M, Z), V.T)) - except np.linalg.LinAlgError: - logger.warning(f"Failed to calculate inverse at update point {j}. Using identity matrix.") - R = np.eye(C) - trivial = True - else: - R = np.eye(C) - - # Apply reconstruction to data - if not trivial or not last_trivial: - # Get subrange of data to process - subrange = range(last_n + 1, n + 1) - if len(subrange) > 0: - # Calculate blend coefficients (raised cosine) - blend = (1 - np.cos(np.pi * np.arange(1, len(subrange) + 1) / len(subrange))) / 2 - - # Extract data segment to process (from extended data X) - idx_in_X = range_[subrange] - segment = X[:, idx_in_X] - - # Apply blended reconstruction - X[:, idx_in_X] = blend * finite_matmul(R, segment) + (1 - blend) * finite_matmul(last_R, segment) - - # Update state for next iteration - last_n = n - last_R = R - last_trivial = trivial - - if splits > 1 and k % 10 == 0: - logger.debug(f'Processing block {k + 1}/{splits}') - - if splits > 1: - logger.info('Finished cleaning.') - - # Update the carry buffer for next call (last P samples) - new_carry = X[:, -P:] if X.shape[1] >= P else X - - # Return cleaned data (without the lookahead portion) - outdata = X[:, P : P + S] - - # Update state dictionary - outstate = { - 'M': M, - 'T': T, - 'sos': sos, - 'iir_state': iir_state, - 'cov': cov, - 'carry': new_carry, - 'last_R': last_R, - 'last_trivial': last_trivial, - # Include original filter coefficients and compatibility mode - 'B': b, - 'A': a, - 'compatibility': compatibility, - 'useriemannian': state.get('useriemannian'), - } - - return outdata, outstate diff --git a/src/eegprep/plugins/clean_rawdata/clean_artifacts.py b/src/eegprep/plugins/clean_rawdata/clean_artifacts.py deleted file mode 100644 index 6b0e6cd3..00000000 --- a/src/eegprep/plugins/clean_rawdata/clean_artifacts.py +++ /dev/null @@ -1,314 +0,0 @@ -"""EEG artifact cleaning functions.""" - -import copy -import logging -from typing import Any, Dict, Optional, Sequence, Tuple, Union - -import numpy as np - -# Local imports from the eegprep package -from .clean_flatlines import clean_flatlines -from .clean_drifts import clean_drifts -from .clean_channels import clean_channels -from .clean_channels_nolocs import clean_channels_nolocs -from .clean_asr import clean_asr -from .clean_windows import clean_windows -from .private.masks import mask_to_intervals -from ...functions.miscfunc.misc import round_mat -from ...functions.popfunc.eeg_eegrej import eeg_eegrej - - -logger = logging.getLogger(__name__) -_DISTANCE_MODES = { - 'euclidian': None, - 'euclidean': None, - 'riemannian': 'calib', -} - -# ----------------------------------------------------------------------------- -# Public API -# ----------------------------------------------------------------------------- - - -def clean_artifacts( - EEG: Dict[str, Any], - # Core parameters - ChannelCriterion: Union[float, str, None] = 0.8, - LineNoiseCriterion: Union[float, str, None] = 4.0, - BurstCriterion: Union[float, str, None] = 5.0, - WindowCriterion: Union[float, str, None] = 0.25, - Highpass: Union[Tuple[float, float], str, None] = (0.25, 0.75), - # Detail parameters - ChannelCriterionMaxBadTime: float = 0.5, - BurstCriterionRefMaxBadChns: Union[float, str, None] = 0.075, - BurstCriterionRefTolerances: Union[Tuple[float, float], str, None] = (-np.inf, 5.5), - BurstRejection: bool = False, - WindowCriterionTolerances: Union[Tuple[float, float], str, None] = (-np.inf, 7), - FlatlineCriterion: Union[float, str, None] = 5.0, - NumSamples: int = 50, - SubsetSize: float = 0.25, - NoLocsChannelCriterion: float = 0.45, - NoLocsChannelCriterionExcluded: float = 0.1, - MaxMem: int = 64, - Distance: str = 'euclidian', - # Misc. - Channels: Optional[Sequence[str]] = None, - Channels_ignore: Optional[Sequence[str]] = None, - availableRAM_GB: Optional[float] = None, -) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], np.ndarray]: - """All-in-one artifact removal, port of MATLAB clean_artifacts. - - Removes flatline channels, low-frequency drifts, noisy channels, short-time bursts, - and irrecoverable windows in sequence. Core parameters can be passed as None or 'off' - to use defaults or disable stages. - - Parameters - ---------- - EEG : dict - Raw continuous EEG dataset dict (must include 'data', 'srate', 'chanlocs', etc.). - ChannelCriterion : float or 'off' - Minimum channel correlation threshold for channel cleaning; channels below - this value are considered bad. Pass 'off' to skip channel criterion. Default 0.8. - LineNoiseCriterion : float or 'off' - Z-score threshold for line-noise contamination; channels exceeding this are - considered bad. 'off' disables line-noise check. Default 4.0. - BurstCriterion : float or 'off' - ASR standard-deviation cutoff for high-amplitude bursts; values above this - relative to calibration data are repaired (or removed if BurstRejection='on'). - 'off' skips ASR. Default 5.0. - WindowCriterion : float or 'off' - Fraction (0-1) or count of channels allowed to be bad per window; windows with - more bad channels are removed. 'off' disables final window removal. Default 0.25. - Highpass : tuple(float, float) or 'off' - Transition band [low, high] in Hz for initial high-pass filtering. 'off' skips - drift removal. Default (0.25, 0.75). - ChannelCriterionMaxBadTime : float - Maximum tolerated time (seconds or fraction of recording) a channel may be flagged - bad before being removed. Default 0.5. - BurstCriterionRefMaxBadChns : float or 'off' - Maximum fraction of bad channels tolerated when selecting calibration data for ASR. - 'off' uses all data for calibration. Default 0.075. - BurstCriterionRefTolerances : tuple(float, float) or 'off' - Power Z-score tolerances for selecting calibration windows in ASR. 'off' uses - all data. Default (-inf, 5.5). - BurstRejection : bool - 'on' to reject (drop) burst segments instead of reconstructing with ASR, - 'off' to apply ASR repair. Default 'off'. - WindowCriterionTolerances : tuple(float, float) or 'off' - Power Z-score bounds for final window removal. 'off' disables this stage. - Default (-inf, 7). - FlatlineCriterion : float or 'off' - Maximum flatline duration in seconds; channels exceeding this are removed. - 'off' disables flatline removal. Default 5.0. - NumSamples : int - Number of RANSAC samples for channel cleaning. Default 50. - SubsetSize : float - Size of channel subsets for RANSAC, as fraction (0-1) or count. Default 0.25. - NoLocsChannelCriterion : float - Correlation threshold for fallback channel cleaning when no channel locations. - Default 0.45. - NoLocsChannelCriterionExcluded : float - Fraction of channels excluded when assessing correlation in nolocs cleaning. - Default 0.1. - MaxMem : int - Maximum memory in MB for ASR processing. Default 64. - Distance : str - Distance metric for ASR calibration ('euclidian' or 'riemannian'). The - Riemannian path uses EEGPrep's calibration-time estimate; full - Riemannian ASR processing is not ported. - Channels : sequence of str or None - List of channel labels to include before cleaning (pop_select). The - returned dataset contains only these channels; channels outside this list - are dropped and not re-inserted. Default None. - Channels_ignore : sequence of str or None - List of channel labels to exclude before cleaning. The excluded channels - are dropped from the returned dataset and not re-inserted. Default None. - availableRAM_GB : float or None - Available system RAM in GB to adjust MaxMem. Default None. - - Returns - ------- - EEG : dict - Final cleaned EEG dataset. - HP : dict - EEG dataset after initial high-pass (drift removal). - BUR : dict - EEG dataset after ASR burst repair (before final window removal). - removed_channels : ndarray of bool - Mask indicating which channels were removed during cleaning. - """ - # ------------------------------------------------------------------ - # Basic argument sanity / aliases - # ------------------------------------------------------------------ - if availableRAM_GB is not None and not np.isnan(availableRAM_GB): - MaxMem = int(round_mat(availableRAM_GB * 1000)) - - if Channels is not None and Channels_ignore is not None and len(Channels) and len(Channels_ignore): - raise ValueError('"Channels" and "Channels_ignore" are mutually exclusive – supply at most one.') - - distance = str(Distance).strip().lower() - if distance not in _DISTANCE_MODES: - raise ValueError("Distance must be 'euclidian', 'euclidean', or 'riemannian'") - - EEG = copy.deepcopy(EEG) - - # Ensure some obligatory fields exist in the structure (MATLAB code assumes) - if 'etc' not in EEG: - EEG['etc'] = {} - - # ------------------------------------------------------------------ - # Optional: restrict to / ignore certain channels - # ------------------------------------------------------------------ - if Channels is not None and len(Channels): - # Attempt pop_select based on labels; the manual fallback only covers the - # documented case where pop_select is unavailable (ImportError). Errors - # raised inside pop_select itself must surface, not be masked. - try: - from eegprep import pop_select - - EEG = pop_select(EEG, channel=list(Channels)) - except ImportError: - lbl_to_idx = {ch['labels']: idx for idx, ch in enumerate(EEG['chanlocs'])} - keep_idx = [lbl_to_idx[lbl] for lbl in Channels if lbl in lbl_to_idx] - EEG['data'] = EEG['data'][keep_idx, :] - EEG['chanlocs'] = [EEG['chanlocs'][i] for i in keep_idx] - EEG['nbchan'] = len(keep_idx) - elif Channels_ignore is not None and len(Channels_ignore): - try: - from eegprep import pop_select - - EEG = pop_select(EEG, nochannel=list(Channels_ignore)) - except ImportError: - lbl_to_idx = {ch['labels']: idx for idx, ch in enumerate(EEG['chanlocs'])} - drop_idx_set = {lbl_to_idx[lbl] for lbl in Channels_ignore if lbl in lbl_to_idx} - keep_idx = [i for i in range(len(EEG['chanlocs'])) if i not in drop_idx_set] - EEG['data'] = EEG['data'][keep_idx, :] - EEG['chanlocs'] = [EEG['chanlocs'][i] for i in keep_idx] - EEG['nbchan'] = len(keep_idx) - - # ------------------------------------------------------------------ - # 1) Flat‑line channel removal - # ------------------------------------------------------------------ - if FlatlineCriterion not in (None, 'off'): - logger.info('Detecting flat line channels...') - EEG = clean_flatlines(EEG, max_flatline_duration=float(FlatlineCriterion)) - - # ------------------------------------------------------------------ - # 2) High‑pass filtering - # ------------------------------------------------------------------ - if Highpass not in (None, 'off'): - if not isinstance(Highpass, (tuple, list)) or len(Highpass) != 2: - raise ValueError('Highpass must be a (low, high) tuple or None/"off".') - logger.info('Applying high‑pass filter...') - EEG = clean_drifts(EEG, tuple(Highpass)) - # Keep a point-in-time snapshot after HP for optional return. Deep-copy so - # later stages that mutate EEG['etc'] (channel/sample masks) do not bleed - # into the returned high-pass dataset. - HP = copy.deepcopy(EEG) - - # ------------------------------------------------------------------ - # 3) Channel cleaning (noisy / disconnected) - # ------------------------------------------------------------------ - removed_channels = np.zeros(EEG['nbchan'], dtype=bool) - if ChannelCriterion not in (None, 'off') or LineNoiseCriterion not in (None, 'off'): - chancorr_crit = 0.0 if ChannelCriterion in (None, 'off') else float(ChannelCriterion) - line_crit = 100.0 if LineNoiseCriterion in (None, 'off') else float(LineNoiseCriterion) - try: - # Match MATLAB clean_artifacts.m line 254: - # clean_channels(EEG,chancorr_crit,line_crit,[],channel_crit_maxbad_time, num_samples) - # MATLAB passes [] for window_len (uses default 5) and doesn't pass subset_size (uses default 0.25) - # Python: not passing window_len uses default 5, and we pass subset_size explicitly (default 0.25) - EEG = clean_channels( - EEG, - corr_threshold=chancorr_crit, - noise_threshold=line_crit, - # window_len not passed - uses default 5 (MATLAB passes [] which also uses default 5) - max_broken_time=float(ChannelCriterionMaxBadTime), - num_samples=int(NumSamples), - subset_size=SubsetSize, # Default 0.25, matches MATLAB default when not passed - ) - # clean_channels only writes clean_channel_mask when it removes channels; - # an absent mask means nothing was removed, so keep the all-False default. - mask = EEG.get('etc', {}).get('clean_channel_mask') - if mask is not None: - removed_channels = ~mask - except ValueError as e: - # Only the missing-channel-locations case warrants the no-locs fallback; - # any other ValueError is a genuine failure and must propagate. - if 'location' not in str(e).lower(): - raise - logger.warning(f'clean_channels lacks usable locations ({e}); falling back to clean_channels_nolocs.') - EEG, removed_channels = clean_channels_nolocs( - EEG, - min_corr=float(NoLocsChannelCriterion), - ignored_quantile=float(NoLocsChannelCriterionExcluded), - window_len=2.0, - max_broken_time=float(ChannelCriterionMaxBadTime), - ) - - # ------------------------------------------------------------------ - # 4) Burst repair via ASR - # ------------------------------------------------------------------ - BUR = EEG # default in case ASR is skipped - if BurstCriterion not in (None, 'off'): - logger.info('Applying ASR burst repair...') - # Snapshot the pre-repair data to compare against the ASR-repaired - # result; clean_asr returns a fresh dataset (BUR) and leaves EEG intact. - original_data = EEG['data'].copy() if BurstRejection else None - useriemannian = _DISTANCE_MODES[distance] - BUR = clean_asr( - EEG, - cutoff=float(BurstCriterion), - ref_maxbadchannels=BurstCriterionRefMaxBadChns, - ref_tolerances=BurstCriterionRefTolerances, - use_gpu=False, - useriemannian=useriemannian, - maxmem=int(MaxMem), - ) - - if BurstRejection: - # Determine unchanged samples: compare the pre-repair snapshot with - # the ASR-repaired data returned in BUR. - sample_mask = np.sum(np.abs(original_data - BUR['data']), axis=0) < 1e-8 - del original_data - # Convert retained samples to inclusive zero-based intervals. - retain_intervals = mask_to_intervals(sample_mask, value=True) - 1 - - # Remove very short intervals < 5 samples - if retain_intervals.size: - lengths = retain_intervals[:, 1] - retain_intervals[:, 0] - small = lengths < 5 - for s, e in retain_intervals[small]: - sample_mask[s : e + 1] = False - retain_intervals = retain_intervals[~small] - - # Reject bad periods from the ASR-repaired dataset (BUR). - EEG = BUR - rejected_intervals = mask_to_intervals(sample_mask, value=False) - if rejected_intervals.size: - EEG = eeg_eegrej(EEG, rejected_intervals) - - # Update mask in EEG.etc - EEG['etc']['clean_sample_mask'] = sample_mask - else: - EEG = BUR - - # ------------------------------------------------------------------ - # 5) Post‑clean windows stage - # ------------------------------------------------------------------ - if WindowCriterion not in (None, 'off') and WindowCriterionTolerances not in (None, 'off'): - logger.info('Final post‑processing – removing irrecoverable windows...') - EEG, _ = clean_windows( - EEG, - max_bad_channels=float(WindowCriterion), - zthresholds=WindowCriterionTolerances, - ) - - logger.info('Use vis_artifacts to compare the cleaned data to the original.') - - # When Channels/Channels_ignore restrict the dataset, the returned EEG holds - # only the cleaned subset; excluded channels are not re-inserted (unlike the - # MATLAB reference). Callers that need the ignored channels back must merge - # them manually. This is documented on the Channels/Channels_ignore parameters. - - return EEG, HP, BUR, removed_channels diff --git a/src/eegprep/plugins/clean_rawdata/clean_asr.py b/src/eegprep/plugins/clean_rawdata/clean_asr.py deleted file mode 100644 index 76470820..00000000 --- a/src/eegprep/plugins/clean_rawdata/clean_asr.py +++ /dev/null @@ -1,233 +0,0 @@ -"""EEG ASR (Artifact Subspace Reconstruction) cleaning utilities. - -This module provides functions for running the Artifact Subspace Reconstruction method -on EEG data to remove artifacts. -""" - -import logging -from typing import Dict, Any, Union, Tuple, Optional -from copy import deepcopy - -import math -import numpy as np - -# Assuming these utilities exist and are correctly ported/placed -from .asr_calibrate import asr_calibrate -from .asr_process import asr_process -from .clean_windows import clean_windows -from ...functions.miscfunc.misc import round_mat - -logger = logging.getLogger(__name__) - - -def clean_asr( - EEG: Dict[str, Any], - cutoff: float = 5.0, - window_len: Optional[float] = None, - step_size: Optional[int] = None, - max_dims: float = 0.66, - ref_maxbadchannels: Union[float, str, None, np.ndarray] = 0.075, - ref_tolerances: Union[Tuple[float, float], str, None] = (-3.5, 5.5), - ref_wndlen: Union[float, str, None] = 1.0, - use_gpu: bool = False, - useriemannian: Optional[str] = None, - maxmem: Optional[int] = 64, -) -> Dict[str, Any]: - """Run the Artifact Subspace Reconstruction (ASR) method on EEG data. - - This is an automated artifact rejection function that ensures that the data - contains no events that have abnormally strong power; the subspaces on which - those events occur are reconstructed (interpolated) based on the rest of the - EEG signal during these time periods. - - Args: - EEG (Dict[str, Any]): EEG data structure. Expected fields: - 'data' (np.ndarray): Channels x Samples matrix. - 'srate' (float): Sampling rate in Hz. - 'nbchan' (int): Number of channels. - It's assumed the data is zero-mean (e.g., high-pass filtered). - cutoff (float, optional): Standard deviation cutoff for rejection. Data portions whose variance - is larger than this threshold relative to the calibration data are - considered artifactual and removed. Aggressive: 3, Default: 5, Conservative: 20. - window_len (float, optional): Length of the statistics window in seconds. Should not be much longer - than artifact timescale. Samples in window should be >= 1.5x channels. - Default: max(0.5, 1.5 * nbchan / srate). - step_size (int, optional): Step size for processing in samples. Reconstruction matrix updated every - `step_size` samples. If None, defaults to window_len / 2 samples. - max_dims (float, optional): Maximum dimensionality/fraction of dimensions to reconstruct. Default: 0.66. - ref_maxbadchannels (Union[float, str, np.ndarray], optional): Parameter for automatic calibration data selection. - float: Max fraction (0-1) of bad channels tolerated in a window for it to be used as calibration data. Lower is more aggressive (e.g., 0.05). Default: 0.075. - 'off': Use all data for calibration. Assumes artifact contamination < ~30-50%. - np.ndarray: Directly provides the calibration data (channels x samples). - ref_tolerances (Union[Tuple[float, float], str], optional): Power tolerances (lower, upper) in SDs from robust EEG power - for a channel to be considered 'bad' during calibration data selection. Default: (-3.5, 5.5). Use 'off' to disable. - ref_wndlen (Union[float, str], optional): Window length in seconds for calibration data selection granularity. Default: 1.0. Use 'off' to disable. - use_gpu (bool, optional): Whether to try using GPU (requires compatible hardware and libraries, currently ignored). Default: False. - useriemannian (str, optional): Option to use a Riemannian ASR variant. Set to 'calib' to use a Riemannian estimate - at calibration time. Full Riemannian ASR processing is not ported in EEGPrep. Default: None (disabled). - maxmem (Optional[int], optional): Maximum memory in MB (passed to asr_calibrate/process, but chunking based on it is not implemented in Python port). Default: 64. - - Returns - ------- - Dict[str, Any] : The EEG dictionary with the 'data' field containing the cleaned data. - - Raises - ------ - ImportError : If automatic calibration data selection is needed (`ref_maxbadchannels` is float) but `clean_windows` cannot be imported. - ValueError : If input arguments are invalid, full Riemannian ASR processing is requested, or calibration fails critically. - """ - if 'data' not in EEG or 'srate' not in EEG or 'nbchan' not in EEG: - raise ValueError("EEG dictionary must contain 'data', 'srate', and 'nbchan'.") - useriemannian = _normalise_useriemannian(useriemannian) - - # Operate on a copy so the caller's data array (and dict) are never mutated; - # asr_calibrate zeroes non-finite samples in place on whatever array it receives. - data = np.array(EEG['data'], dtype=np.float64, copy=True) - srate = float(EEG['srate']) - nbchan = int(EEG['nbchan']) - C, S = data.shape - - if C != nbchan: - logger.warning(f"Mismatch between EEG['nbchan'] ({nbchan}) and EEG['data'].shape[0] ({C}). Using shape[0].") - nbchan = C # Use the actual dimension from data - - # --- Handle Defaults --- - if window_len is None: - window_len = max(0.5, 1.5 * nbchan / srate) - - # --- Ensure Data Type --- - # Already done with np.asarray above - - # --- Determine Reference/Calibration Data --- - ref_section_data = None - if ( - isinstance(ref_maxbadchannels, (int, float)) - and isinstance(ref_tolerances, (tuple, list)) - and isinstance(ref_wndlen, (int, float)) - ): - logger.info('Finding a clean section of the data for calibration...') - try: - # clean_windows is assumed to return the selected data array (C x S_clean) - # It needs the EEG dict structure, similar to other clean_* funcs - temp_EEG_for_cleanwin = deepcopy(EEG) - temp_EEG_for_cleanwin['data'] = data # ensure it has the float64 data - cleaned_EEG, _ = clean_windows(temp_EEG_for_cleanwin, ref_maxbadchannels, ref_tolerances, ref_wndlen) - ref_section_data = np.asarray(cleaned_EEG['data'], dtype=np.float64) - if ref_section_data.size == 0 or ref_section_data.shape[1] == 0: - logger.warning("clean_windows returned no data. Falling back to using all data for calibration.") - ref_section_data = data - elif ref_section_data.shape[1] < 64: - logger.warning( - "clean_windows returned insufficient data. Falling back to using all data for calibration." - ) - ref_section_data = data - except ValueError as e: - # clean_windows raises ValueError for expected calibration-data problems - # (empty data, window too small, not enough data for one window). Only - # those warrant the all-data fallback; unexpected exceptions propagate so - # genuine bugs are not masked as silently weaker ASR calibration. - logger.warning( - f"Could not automatically identify clean calibration data ({e}). " - "Falling back to using the entire data for calibration." - ) - ref_section_data = data - elif (isinstance(ref_maxbadchannels, str) and ref_maxbadchannels.lower() == 'off') or ref_maxbadchannels is None: - logger.info(f"Using the entire data for calibration ('ref_maxbadchannels' set to {ref_maxbadchannels!r}).") - ref_section_data = data - elif (isinstance(ref_tolerances, str) and ref_tolerances.lower() == 'off') or ref_tolerances is None: - logger.info(f"Using the entire data for calibration ('ref_tolerances' set to {ref_tolerances!r}).") - ref_section_data = data - elif (isinstance(ref_wndlen, str) and ref_wndlen.lower() == 'off') or ref_wndlen is None: - logger.info(f"Using the entire data for calibration ('ref_wndlen' set to {ref_wndlen!r}).") - ref_section_data = data - elif isinstance(ref_maxbadchannels, np.ndarray): - logger.info("Using user-supplied data array for calibration.") - ref_section_data = np.asarray(ref_maxbadchannels, dtype=np.float64) - if ref_section_data.ndim != 2 or ref_section_data.shape[0] != C: - raise ValueError(f"User-supplied calibration data must be a 2D array with shape ({C}, n_samples).") - else: - raise ValueError( - f"Unsupported value or type for 'ref_maxbadchannels': {ref_maxbadchannels}. Must be float, None/'off', or numpy array." - ) - - # --- Calibrate ASR --- - logger.info('Estimating ASR calibration statistics...') - # The Python asr_calibrate uses its own defaults for blocksize, filters, etc. - # We only pass the core parameters specified in the clean_asr call signature. - try: - state = asr_calibrate(ref_section_data, srate, cutoff=cutoff, maxmem=maxmem, useriemannian=useriemannian) - except ValueError as e: - # Catch specific errors like not enough calibration data - raise ValueError(f"ASR calibration failed: {e}") - # except Exception as e: - # # Catch unexpected errors during calibration - # logger.exception("An unexpected error occurred during ASR calibration.") - # raise RuntimeError(f"ASR calibration failed unexpectedly: {e}") - - del ref_section_data # Free memory - - # --- Prepare for Processing --- - if step_size is None: - step_size = int(math.floor(srate * window_len / 2)) # Samples - - # --- Extrapolate Signal End --- - # Required because asr_process needs lookahead data beyond the signal end - # Based on: sig = [signal.data bsxfun(@minus,2*signal.data(:,end),signal.data(:,(end-1):-1:end-round(windowlen/2*signal.srate)))]; - N_extrap = int(round_mat(window_len / 2 * srate)) - if N_extrap > 0: - # Calculate indices for reflection, handling edge case where N_extrap >= S-1 - extrap_len = min(N_extrap, S - 1 if S > 1 else 0) - if extrap_len > 0: - # Indices from second-to-last sample back 'extrap_len' steps - extrap_indices = np.arange(S - 2, S - extrap_len - 2, -1) - # Reflect around the last sample: 2*last_sample - samples_before_last - extrap_part = 2 * data[:, [-1]] - data[:, extrap_indices] - sig = np.concatenate((data, extrap_part), axis=1) - else: # Not enough data to extrapolate - sig = data - else: # No extrapolation needed - sig = data - - # --- Process Signal using ASR --- - logger.info('Applying ASR processing...') - lookahead_sec = window_len / 2.0 # asr_process expects lookahead in seconds - outdata, _ = asr_process( - sig, - srate, - state, - window_len=window_len, - lookahead=lookahead_sec, - step_size=step_size, - max_dims=max_dims, - max_mem=maxmem, - use_gpu=use_gpu, # Passed but ignored in current Python port - ) - - # --- Finalize --- - # shift signal content back (to compensate for processing delay) - outdata = outdata[:, :S] - EEG = deepcopy(EEG) - EEG['data'] = outdata - logger.info('ASR cleaning finished.') - - return EEG - - -def _normalise_useriemannian(value: Optional[str]) -> Optional[str]: - if value is None: - return None - if isinstance(value, bool): - if not value: - return None - raise ValueError("clean_asr supports useriemannian='calib' only; full Riemannian ASR processing is not ported.") - if isinstance(value, str): - lower = value.strip().lower() - if lower in {"", "0", "false", "no", "none", "off"}: - return None - if lower in {"calib", "calibration"}: - return "calib" - if lower in {"1", "true", "yes", "on", "all", "process", "riemannian"}: - raise ValueError( - "clean_asr supports useriemannian='calib' only; full Riemannian ASR processing is not ported." - ) - raise ValueError("clean_asr useriemannian must be None, 'off', or 'calib'.") diff --git a/src/eegprep/plugins/clean_rawdata/clean_channels.py b/src/eegprep/plugins/clean_rawdata/clean_channels.py deleted file mode 100644 index 30b05bb7..00000000 --- a/src/eegprep/plugins/clean_rawdata/clean_channels.py +++ /dev/null @@ -1,167 +0,0 @@ -"""EEG channel cleaning utilities.""" - -import logging -from typing import Any, Dict - -import numpy as np - -from eegprep.plugins.firfilt.design import design_fir - -from ...functions.miscfunc.misc import finite_matmul, round_mat -from .private.channel_removal import remove_channels_without_pop_select, update_clean_channel_mask -from .private.ransac import calc_projector -from .private.sigproc import filtfilt_fast -from .private.stats import mad - -logger = logging.getLogger(__name__) - - -def clean_channels( - EEG: Dict[str, Any], - corr_threshold: float = 0.8, - noise_threshold: float = 5.0, - window_len: float = 5, - max_broken_time: float = 0.4, - num_samples: int = 50, - subset_size: float = 0.25, -) -> Dict[str, Any]: - """Remove channels with problematic data from a continuous data set. - - This is an automated artifact rejection function which ensures that the data contains no channels - that record only noise for extended periods of time. If channels with control signals are - contained in the data these are usually also removed. The criterion is based on correlation: if a - channel has lower correlation to its robust estimate (based on other channels) than a given threshold - for a minimum period of time (or percentage of the recording), it will be removed. - - Args: - EEG: Continuous data set, assumed to be appropriately high-passed - (e.g. >0.5Hz or with a 0.5Hz - 2.0Hz transition band). - corr_threshold: Correlation threshold. If a channel is correlated at - less than this value to its robust estimate (based on other channels), - it is considered abnormal in the given time window. - noise_threshold: If a channel has more (high-frequency) noise relative to its signal - than this value, in standard deviations from the channel population mean, - it is considered abnormal. - window_len: Length of the windows (in seconds) for which correlation is computed; ideally - short enough to reasonably capture periods of global artifacts or intermittent - sensor dropouts, but not shorter (for statistical reasons). - max_broken_time: Maximum time (either in seconds or as fraction of the recording) - during which a channel is allowed to have artifacts. Reasonable range: - 0.1 (very aggressive) to 0.6 very lax). - num_samples: Number of samples generated for a RANSAC reconstruction. This is the - number of samples to generate in the random sampling consensus process. The larger - this value, the more robust but also slower the processing will be. - subset_size: Subset size. This is the size of the channel subsets to use - for robust reconstruction, as a number or fraction of the total number - of channels. - - Returns - ------- - EEG : data set with bad channels removed - """ - EEG['data'] = np.asarray(EEG['data'], dtype=np.float64) - C, S = EEG['data'].shape - Fs = EEG['srate'] - - # handle fractions or absolute values - if subset_size >= 1: - subset_size = int(subset_size) - else: - subset_size = int(round_mat(C * subset_size)) - if max_broken_time < 1: - max_broken_time = S * max_broken_time - else: - max_broken_time = round_mat(Fs) * max_broken_time - - window_len = int(window_len * round_mat(Fs)) - wnd = np.arange(int(window_len)) - offsets = np.arange(0, S - window_len, window_len, dtype=int) - W = len(offsets) - - logger.info('Scanning for bad channels...') - - if Fs > 100: - # remove signal content above 50Hz - B = design_fir(100, 2 * np.array([0, 45, 50, Fs / 2]) / Fs, [1, 1, 0, 0]) - X = np.zeros((S, C)) - for c in range(C): - X[:, c] = filtfilt_fast(B, 1, EEG['data'][c, :]) - - # determine z-scored level of EM noise-to-signal ratio for each channel - noisiness = mad(EEG['data'].T - X) / mad(X) - znoise = (noisiness - np.median(noisiness)) / (mad(noisiness) * 1.4826) - - # trim channels based on that - noise_mask = znoise > noise_threshold - else: - X = EEG['data'].T - noise_mask = np.zeros(C, dtype=bool) # transpose added in MATLAB comment - - # get the matrix of all channel locations [3xN] - xyz = [[ch.get(coord, np.nan) for ch in EEG['chanlocs']] for coord in ['X', 'Y', 'Z']] - xyz = [[x if not (isinstance(x, np.ndarray) and x.size == 0) else np.nan for x in xyz_sub] for xyz_sub in xyz] - xyz = np.asarray([np.asarray([np.nan if x is None else x for x in row], dtype=float) for row in xyz]) - if np.mean(np.any(np.isnan(xyz), axis=0)) > 0.5: - raise ValueError('To use this function most of your channels should have X,Y,Z location measurements.') - usable_channels = np.where(~np.any(np.isnan(xyz), axis=0))[0] - - locs = xyz[:, usable_channels].T - X = np.asarray(X[:, usable_channels]) - - # replicate MATLAB's default randseed, for exact compatibility - stream = np.random.RandomState(5489) - P = np.asarray(calc_projector(locs, num_samples, subset_size, stream=stream)) - corrs = np.zeros((len(usable_channels), W)) - - # calculate each channel's correlation to its RANSAC reconstruction for each window - time_passed_list = np.zeros(W) - for o in range(W): - import time - - start_time = time.time() - - XX = X[offsets[o] + wnd, :] - YY = np.sort(np.reshape(finite_matmul(XX, P).T, (num_samples, -1)), axis=0) - YY = np.reshape(YY[int(round_mat(num_samples / 2)) - 1, :], (-1, window_len)).T - - # Calculate correlation for each channel - for c in range(len(usable_channels)): - numerator = np.sum(XX[:, c] * YY[:, c]) - denominator = np.sqrt(np.sum(XX[:, c] ** 2)) * np.sqrt(np.sum(YY[:, c] ** 2)) - corrs[c, o] = numerator / denominator - - time_passed_list[o] = time.time() - start_time - median_time_passed = np.median(time_passed_list[: o + 1]) - if o % 50 == 0: - logger.info(f'{o + 1:3d}/{W} blocks, {median_time_passed * (W - o - 1) / 60:.1f} minutes remaining.') - - flagged = corrs < corr_threshold - - # mark all channels for removal which have more flagged samples than the maximum number of - # ignored samples - removed_channels = np.zeros(C, dtype=bool) - removed_channels[usable_channels] = np.sum(flagged, axis=1) * window_len > max_broken_time - removed_channels = removed_channels | noise_mask - - # apply removal - if np.mean(removed_channels) > 0.75: - raise ValueError( - 'More than 75% of your channels were removed -- this is probably caused by incorrect channel location measurements (e.g., wrong cap design).' - ) - elif np.any(removed_channels): - try: - from eegprep import pop_select - - EEG = pop_select(EEG, nochannel=list(np.where(removed_channels)[0])) - except Exception as e: - if isinstance(e, ImportError): - logger.error("Apparently you do not have EEGLAB's pop_select() on the path.") - else: - logger.error("Could not select channels using EEGLAB's pop_select(); details: %s", str(e)) - logger.debug("Exception traceback:", exc_info=True) - - logger.info(f'Removing {np.sum(removed_channels)} channels and dropping signal meta-data.') - EEG = remove_channels_without_pop_select(EEG, removed_channels) - update_clean_channel_mask(EEG, removed_channels) - - return EEG diff --git a/src/eegprep/plugins/clean_rawdata/clean_channels_nolocs.py b/src/eegprep/plugins/clean_rawdata/clean_channels_nolocs.py deleted file mode 100644 index e20900a8..00000000 --- a/src/eegprep/plugins/clean_rawdata/clean_channels_nolocs.py +++ /dev/null @@ -1,125 +0,0 @@ -"""EEG channel cleaning utilities without locations.""" - -import logging -from typing import Any, Dict, Tuple - -import numpy as np - -from eegprep.plugins.firfilt.design import design_fir, design_kaiser - -from .private.channel_removal import remove_channels_without_pop_select, update_clean_channel_mask -from .private.sigproc import filtfilt_fast - -logger = logging.getLogger(__name__) - - -def clean_channels_nolocs( - EEG: Dict[str, Any], - min_corr: float = 0.45, - ignored_quantile: float = 0.1, - window_len: float = 2.0, - max_broken_time: float = 0.5, - linenoise_aware: bool = True, -) -> Tuple[Dict[str, Any], np.ndarray]: - """Remove channels with abnormal data from a continuous data set. - - This is an automated artifact rejection function which ensures that the data - contains no channels that record only noise for extended periods of time. If - channels with control signals are contained in the data these are usually also - removed. The criterion is based on correlation: if a channel is decorrelated - from all others (pairwise correlation < a given threshold), excluding a given - fraction of most correlated channels -- and if this holds on for a sufficiently - long fraction of the data set -- then the channel is removed. - - Args: - EEG: Continuous data set, assumed to be appropriately high-passed (e.g. >0.5Hz or - with a 0.5Hz - 2.0Hz transition band). - min_corr: Minimum correlation between a channel and any other channel (in - a short period of time) below which the channel is considered abnormal - for that time period. Reasonable range: 0.4 (very lax) to 0.6 (quite aggressive). - ignored_quantile: Fraction of channels that need to have at least the given - MinCorrelation value w.r.t. the channel under consideration. This allows - to deal with channels or small groups of channels that measure the same - noise source. Reasonable range: 0.05 (rather lax) to 0.2 (very tolerant re - disconnected/shorted channels). - window_len: Length of the windows (in seconds) for which correlation is computed. - max_broken_time: Maximum time (either in seconds or as fraction of the - recording) during which a retained channel may be broken. Reasonable - range: 0.1 (very aggressive) to 0.6 (very lax). - linenoise_aware: Whether the operation should be performed in a line-noise - aware manner. If enabled, the correlation measure will not be affected - by the presence or absence of line noise (using a temporary notch filter). - - Returns - ------- - EEG : data set with bad channels removed - removed_channels : boolean array indicating which channels were removed - """ - Fs = EEG['srate'] - - # Flag channels - if 0 < max_broken_time < 1: - max_broken_time = EEG['data'].shape[1] * max_broken_time - else: - max_broken_time = Fs * max_broken_time - - EEG['data'] = np.asarray(EEG['data'], dtype=np.float64) - C, S, *_ = EEG['data'].shape - window_len = window_len * Fs - wnd = np.arange(int(window_len)) - offsets = np.arange(0, int(S - window_len), window_len, dtype=int) - W = len(offsets) - retained = np.arange(C - int(np.ceil(C * ignored_quantile))) - - # Optionally ignore both 50 and 60 Hz spectral components - if linenoise_aware: - Bwnd = design_kaiser(2 * 45 / Fs, 2 * 50 / Fs, 60, True) - - if Fs <= 110: - raise ValueError('Sampling rate must be above 110 Hz') - elif Fs <= 130: - B = design_fir(len(Bwnd) - 1, 2 * np.array([0, 45, 50, 55, Fs / 2]) / Fs, [1, 1, 0, 1, 1], w=Bwnd) - else: - B = design_fir( - len(Bwnd) - 1, 2 * np.array([0, 45, 50, 55, 60, 65, Fs / 2]) / Fs, [1, 1, 0, 1, 0, 1, 1], w=Bwnd - ) - - X = np.zeros((S, C)) - for c in range(C): - X[:, c] = filtfilt_fast(B, 1.0, EEG['data'][c, :]) - else: - X = EEG['data'].T - - # For each window, flag channels with too low correlation to any other channel - flagged = np.zeros((C, W), dtype=bool) - for o in range(W): - window_data = X[offsets[o] + wnd, :] - corrmat = np.abs(np.corrcoef(window_data, rowvar=False)) - sortcc = np.sort(corrmat, axis=0) - flagged[:, o] = np.all(sortcc[retained, :] < min_corr, axis=0) - - # Mark channels for removal which have more flagged samples than the maximum - removed_channels = np.sum(flagged, axis=1) * window_len > max_broken_time - - # Apply removal - if np.all(removed_channels): - logger.warning('All channels are flagged bad according to the used criterion: not removing anything.') - elif np.any(removed_channels): - logger.info('Now removing bad channels...') - try: - # Try to use pop_select if available - from eegprep import pop_select - - EEG = pop_select(EEG, nochannel=list(np.where(removed_channels)[0])) - except Exception as e: - if isinstance(e, ImportError): - logger.error('Apparently you do not have access to a pop_select() function.') - else: - logger.error('Could not select channels using EEGLAB\'s pop_select(); details: %s', str(e)) - logger.debug('Exception traceback:', exc_info=True) - - logger.info('Falling back to a basic substitute and dropping signal meta-data.') - EEG = remove_channels_without_pop_select(EEG, removed_channels) - update_clean_channel_mask(EEG, removed_channels) - - return EEG, removed_channels diff --git a/src/eegprep/plugins/clean_rawdata/clean_drifts.py b/src/eegprep/plugins/clean_rawdata/clean_drifts.py deleted file mode 100644 index 1d850474..00000000 --- a/src/eegprep/plugins/clean_rawdata/clean_drifts.py +++ /dev/null @@ -1,52 +0,0 @@ -"""EEG drift removal utilities.""" - -import logging -from typing import Any, Dict, Sequence - -import numpy as np -from scipy.signal import filtfilt - -from eegprep.plugins.firfilt.design import design_fir, design_kaiser - -from .private.sigproc import filtfilt_fast - -logger = logging.getLogger(__name__) - - -def clean_drifts( - EEG: Dict[str, Any], - transition: Sequence[float] = (0.5, 1), - attenuation: float = 80.0, - method: str = 'fft', -) -> Dict[str, Any]: - """Remove drifts from the data using a forward-backward high-pass filter. - - This removes drifts from the data using a forward-backward (non-causal) filter. - NOTE: If you are doing directed information flow analysis, do no use this filter but some other one. - - Args: - EEG: the continuous-time EEG data structure - transition: the transition band in Hz, i.e. lower and upper edge of the - transition as in (lo,hi) - attenuation: stop-band attenuation, in dB - method: the method to use for filtering ('fft' or 'fir') - - Returns - ------- - EEG : the filtered EEG data structure - """ - EEG['data'] = np.asarray(EEG['data'], dtype=np.float64) - - # design highpass FIR filter - transition = 2 * np.asarray(transition) / EEG['srate'] - wnd = design_kaiser(transition[0], transition[1], attenuation, True) - B = design_fir(len(wnd) - 1, np.concatenate(([0], transition, [1])), [0, 0, 1, 1], w=wnd) - - op = filtfilt if method == 'fir' else filtfilt_fast - - # apply it, channel by channel to save memory - for i in range(EEG['data'].shape[0]): - EEG['data'][i, :] = op(B, 1, EEG['data'][i, :]) - EEG['etc']['clean_drifts_kernel'] = B - - return EEG diff --git a/src/eegprep/plugins/clean_rawdata/clean_flatlines.py b/src/eegprep/plugins/clean_rawdata/clean_flatlines.py deleted file mode 100644 index bb2c4207..00000000 --- a/src/eegprep/plugins/clean_rawdata/clean_flatlines.py +++ /dev/null @@ -1,67 +0,0 @@ -"""EEG flatline channel removal utilities.""" - -from typing import Any, Dict -import logging - -import numpy as np - -from .private.channel_removal import remove_channels_without_pop_select, update_clean_channel_mask - -logger = logging.getLogger(__name__) - - -def clean_flatlines(EEG: Dict[str, Any], max_flatline_duration: float = 5.0, max_allowed_jitter: float = 20.0): - """Remove (near-) flat-lined channels. - - This is an automated artifact rejection function which ensures that - the data contains no flat-lined channels. - - Args: - EEG: the continuous-time EEG data structure - max_flatline_duration: maximum tolerated flatline duration. In seconds. - If a channel has a longer flatline than this, it will be considered - abnormal. - max_allowed_jitter: maximum tolerated jitter during flatlines. As a - multiple of epsilon. - - Returns - ------- - EEG : the EEG data structure with flatlined channels removed. - - Example: - EEG = clean_flatlines(EEG) - """ - X = EEG['data'] - max_duration = max_flatline_duration * EEG['srate'] - max_jitter = max_allowed_jitter * np.finfo(np.float64).eps - - # flag channels - removed_channels = np.zeros(X.shape[0], dtype=bool) - for i in range(X.shape[0]): - flat = np.pad(np.abs(np.diff(X[i, :])) < max_jitter, 1) - flat_intervals = np.reshape(np.where(np.diff(flat) > 0)[0], (-1, 2)) - if flat_intervals.shape[0] > 0: - if np.max(flat_intervals[:, 1] - flat_intervals[:, 0]) > max_duration: - removed_channels[i] = True - - # remove them - if np.all(removed_channels): - logger.warning('All channels have a flat-line portion; not removing anything.') - elif np.any(removed_channels): - # noinspection PyBroadException - try: - # noinspection PyUnresolvedReferences - from eegprep import pop_select - - EEG = pop_select(EEG, nochannel=list(np.where(removed_channels)[0])) - except Exception as e: - if isinstance(e, ImportError): - logger.error('Apparently you do not have access to a pop_select() function.') - else: - logger.error('Could not select channels using EEGLAB\'s pop_select(); details: %s', str(e)) - logger.debug('Exception traceback:', exc_info=True) - logger.info('Falling back to a basic substitute and dropping signal meta-data.') - EEG = remove_channels_without_pop_select(EEG, removed_channels) - update_clean_channel_mask(EEG, removed_channels) - - return EEG diff --git a/src/eegprep/plugins/clean_rawdata/clean_windows.py b/src/eegprep/plugins/clean_rawdata/clean_windows.py deleted file mode 100644 index 732d17dd..00000000 --- a/src/eegprep/plugins/clean_rawdata/clean_windows.py +++ /dev/null @@ -1,214 +0,0 @@ -"""EEG data window cleaning utilities. - -This module provides functions for removing periods with abnormally high-power content -from continuous EEG data. -""" - -import logging -from copy import deepcopy -from typing import Any, Dict, Sequence, Tuple, Union - -import numpy as np - -from ...functions.miscfunc.misc import round_mat -from ...functions.popfunc.eeg_eegrej import eeg_eegrej -from .private.masks import mask_to_intervals -from .private.stats import fit_eeg_distribution - -logger = logging.getLogger(__name__) - - -def clean_windows( - EEG: Dict[str, Any], - max_bad_channels: Union[int, float] = 0.2, - zthresholds: Tuple[float, float] = (-3.5, 5), - window_len: float = 1.0, - window_overlap: float = 0.66, - max_dropout_fraction: float = 0.1, - min_clean_fraction: float = 0.25, - truncate_quant: Tuple[float, float] = (0.022, 0.6), - step_sizes: Tuple[float, float] = (0.01, 0.01), - shape_range: Union[np.ndarray, Sequence[float]] = np.arange(1.7, 3.6, 0.15), -) -> Tuple[Dict[str, Any], np.ndarray]: - """Remove periods with abnormally high-power content from continuous data. - - This function cuts segments from the data which contain high-power artifacts. - Specifically, only windows are retained which have less than a certain - fraction of *bad* channels, where a channel is bad in a window if its RMS - power is above or below some *z*-threshold relative to a robust estimate - of clean EEG power in that channel. - - Args - ---- - EEG : dict - Continuous dataset using the EEGLAB dict schema. The data is - expected to be high-passed appropriately (>1 Hz recommended). - max_bad_channels : int | float - The maximum number **or** fraction of channels that may exceed the - thresholds inside a time-window for the window to be kept. Values in - (0,1) are interpreted as a fraction; otherwise as an absolute count. - zthresholds : tuple(float, float) - Lower and upper *z*-score limits for RMS power ([low, high]). - window_len : float - Window length in seconds. Should be at least half a period of the high- - pass cut-off that was used. Default is 1 s. - window_overlap : float - Fractional overlap between consecutive windows (0-1). Higher overlap - finds more artefacts but is slower. Default is 0.66 (≈⅔ overlap). - max_dropout_fraction : float - Maximum fraction of windows that may have arbitrarily low amplitude - (e.g. sensor unplugged). Default is 0.1. - min_clean_fraction : float - Minimum fraction of windows expected to be clean (essentially - uncontaminated EEG). Default is 0.25. - truncate_quant : tuple(float, float) - Quantile range of the truncated Gaussian to fit (default (0.022,0.6)). - step_sizes : tuple(float, float) - Grid-search step sizes in quantiles for lower/upper edge. - shape_range : sequence(float) - Range for the *beta* shape parameter in the generalised Gaussian used - for distribution fitting. - - Returns - ------- - EEG : dict - The passed-in structure with bad time periods excised. - sample_mask : np.ndarray[bool] - Boolean mask (length == original ``pnts``) indicating which samples are - retained (``True``) or removed (``False``). - """ - # ------------------------------------------------------------------ - # Input handling - # ------------------------------------------------------------------ - # Operate on a deep copy so the caller's dataset is never mutated. - EEG = deepcopy(EEG) - input_data = np.asarray(EEG['data']) - output_dtype = input_data.dtype if np.issubdtype(input_data.dtype, np.floating) else np.dtype(np.float64) - EEG['data'] = input_data.astype(np.float64, copy=True) - C, S = EEG['data'].shape - Fs = EEG['srate'] - - # Convert fractional parameters to absolute where necessary - if C == 0 or S == 0: - raise ValueError('Empty data array encountered.') - - if max_bad_channels is not None and 0 < max_bad_channels < 1: - max_bad_channels = int(round_mat(C * max_bad_channels)) - else: - max_bad_channels = int(max_bad_channels) - - shape_range = np.asarray(shape_range) - - # ------------------------------------------------------------------ - # Prepare window indexing helpers - # ------------------------------------------------------------------ - N = int(round_mat(window_len * Fs)) # samples per window - if N <= 0: - raise ValueError('Window length too small - results in N <= 0.') - - # MATLAB: offsets = round(1:N*(1-window_overlap):S-N) - step = N * (1.0 - window_overlap) - if step <= 0: - # Avoid infinite loop when overlap >= 1 - step = 1.0 - offsets = round_mat(np.arange(1, S - N + 1, step)).astype(int) - 1 - if len(offsets) == 0: - raise ValueError('Not enough data for even a single window.') - - wnd = np.arange(N, dtype=int) - W = len(offsets) - - logger.info('Determining time window rejection thresholds...') - - # ------------------------------------------------------------------ - # Compute z-score per channel - # ------------------------------------------------------------------ - wz = np.zeros((C, W), dtype=float) - for c in reversed(range(C)): - # compute RMS amplitude for each window - Xsq = EEG['data'][c, :] ** 2 # power - # Gather samples for all windows using broadcasting (W, N) - indices = offsets[:, None] + wnd[None, :] - # Extract data and compute RMS per window - rms = np.sqrt(np.sum(Xsq[indices], axis=1) / N) - - # Fit distribution to clean EEG portion - mu, sig, *_ = fit_eeg_distribution( - rms, - min_clean_fraction=min_clean_fraction, - max_dropout_fraction=max_dropout_fraction, - quants=truncate_quant, - step_sizes=step_sizes, - beta=shape_range, - ) - if sig == 0 or np.isnan(sig): - # Fallback to robust MAD if fitting failed - sig = np.median(np.abs(rms - np.median(rms))) * 1.4826 - mu = np.median(rms) - if sig == 0: - sig = 1.0 # avoid division by zero - - # z-score relative to fitted distribution - wz[c, :] = (rms - mu) / sig - logger.info('done.') - - # ------------------------------------------------------------------ - # Identify windows to be removed/kept - # ------------------------------------------------------------------ - swz = np.sort(wz, axis=0) # sort each column (ascending) - - remove_mask = np.zeros(W, dtype=bool) - zmin, zmax = zthresholds - if zmax > 0: - # upper threshold – check the (max_bad_channels+1)-th largest value - idx_hi = C - max_bad_channels - 1 # zero-based index - idx_hi = max(min(idx_hi, C - 1), 0) - remove_mask |= swz[idx_hi, :] > zmax - if zmin < 0: - # lower threshold – check the (max_bad_channels+1)-th smallest value - idx_lo = max_bad_channels - idx_lo = max(min(idx_lo, C - 1), 0) - remove_mask |= swz[idx_lo, :] < zmin - - removed_windows = np.where(remove_mask)[0] - - # ------------------------------------------------------------------ - # Convert window removals to sample mask - # ------------------------------------------------------------------ - sample_mask = np.ones(S, dtype=bool) - for w in removed_windows: - start = offsets[w] - sample_mask[start : start + N] = False - - kept_pct = 100.0 * np.mean(sample_mask) - kept_seconds = np.count_nonzero(sample_mask) / Fs - logger.info(f'Keeping {kept_pct:.1f}% ({kept_seconds:.0f} seconds) of the data.') - - # ------------------------------------------------------------------ - # Apply sample rejection - # ------------------------------------------------------------------ - rejected_intervals = mask_to_intervals(sample_mask, value=False) - if rejected_intervals.size: - EEG = eeg_eegrej(EEG, rejected_intervals) - EEG['data'] = np.asarray(EEG['data'], dtype=output_dtype) - - # ------------------------------------------------------------------ - # Update/insert clean_sample_mask - # ------------------------------------------------------------------ - if 'etc' not in EEG: - EEG['etc'] = {} - - etc = EEG['etc'] - if 'clean_sample_mask' in etc: - prev_mask = np.asarray(etc['clean_sample_mask']).astype(bool) - one_inds = np.where(prev_mask)[0] - if len(one_inds) == len(sample_mask): - prev_mask[one_inds] = sample_mask - etc['clean_sample_mask'] = prev_mask - else: - logger.warning('EEG.etc.clean_sample is present but incompatible; it is being overwritten.') - etc['clean_sample_mask'] = sample_mask - else: - etc['clean_sample_mask'] = sample_mask - - return EEG, sample_mask diff --git a/src/eegprep/plugins/clean_rawdata/menu.py b/src/eegprep/plugins/clean_rawdata/menu.py deleted file mode 100644 index 4e320292..00000000 --- a/src/eegprep/plugins/clean_rawdata/menu.py +++ /dev/null @@ -1,15 +0,0 @@ -"""clean_rawdata plugin menu spec for the EEGPrep main window.""" - -from __future__ import annotations - -from eegprep.functions.guifunc.menu_spec import MenuItemSpec, menu_item - - -def clean_rawdata_menu() -> MenuItemSpec: - """Return the EEGLAB clean_rawdata Tools menu item.""" - return menu_item( - "Reject data using Clean Rawdata and ASR", - action="pop_clean_rawdata", - userdata="startup:off;epoch:off;study:on", - origin="clean_rawdata", - ) diff --git a/src/eegprep/plugins/clean_rawdata/pop_clean_rawdata.py b/src/eegprep/plugins/clean_rawdata/pop_clean_rawdata.py deleted file mode 100644 index e9f796f3..00000000 --- a/src/eegprep/plugins/clean_rawdata/pop_clean_rawdata.py +++ /dev/null @@ -1,283 +0,0 @@ -"""EEGLAB-style pop wrapper for clean_rawdata.""" - -from __future__ import annotations - -import copy -import logging -import re -from typing import Any - -import numpy as np - -from eegprep.functions.guifunc.inputgui import inputgui -from eegprep.functions.guifunc.spec import CallbackSpec, ControlSpec, DialogSpec -from eegprep.functions.popfunc._pop_utils import ( - format_history_value, - parse_key_value_args, - parse_text_tokens, -) -from eegprep.plugins.clean_rawdata.clean_artifacts import clean_artifacts -from eegprep.plugins.clean_rawdata.vis_artifacts import vis_artifacts - - -logger = logging.getLogger(__name__) -_SHOW_VIS_ARTIFACTS_KEY = "_show_vis_artifacts" - -_OPTION_ALIASES = { - "channelcriterion": "ChannelCriterion", - "linenoisecriterion": "LineNoiseCriterion", - "burstcriterion": "BurstCriterion", - "windowcriterion": "WindowCriterion", - "highpass": "Highpass", - "flatlinecriterion": "FlatlineCriterion", - "burstrejection": "BurstRejection", - "distance": "Distance", - "channels": "Channels", - "channels_ignore": "Channels_ignore", -} - - -def pop_clean_rawdata( - EEG, - *args, - gui: bool | None = None, - renderer=None, - return_com: bool = False, - **kwargs, -): - """Clean continuous EEG data using the clean_rawdata workflow.""" - if EEG is None: - return (None, "") if return_com else None - options = _normalise_options(parse_key_value_args(args, kwargs, lowercase_keys=False)) - show_vis_artifacts = bool(options.pop(_SHOW_VIS_ARTIFACTS_KEY, False)) - if gui is None: - gui = not bool(options) - if gui: - gui_options = _run_gui(EEG[0] if isinstance(EEG, list) else EEG, renderer=renderer) - if gui_options is None: - return (EEG, "") if return_com else EEG - show_vis_artifacts = bool(gui_options.pop(_SHOW_VIS_ARTIFACTS_KEY, False)) - options.update(gui_options) - if isinstance(EEG, list): - output = [] - for index, item in enumerate(EEG, start=1): - logger.info("Processing group dataset %s of %s.", index, len(EEG)) - output.append(pop_clean_rawdata(item, gui=False, **options)) - command = _history_command(options) - if show_vis_artifacts: - logger.warning("clean_rawdata visual rejected-data browser is disabled for multiple datasets.") - return (output, command) if return_com else output - if int(EEG.get("trials", 1) or 1) > 1 or np.asarray(EEG.get("data")).ndim == 3: - raise ValueError("Input data must be continuous. This data seems epoched.") - # Deep-copy so a failure (or any partial cleaning stage) leaves the caller's - # input dataset untouched and a successful run returns a distinct object. - working_eeg = copy.deepcopy(EEG) - clean_eeg, _hp, _bur, _removed_channels = clean_artifacts(working_eeg, **options) - command = _history_command(options) - if show_vis_artifacts: - vis_artifacts(clean_eeg, EEG) - logger.info("Done.") - return (clean_eeg, command) if return_com else clean_eeg - - -def pop_clean_rawdata_dialog_spec(EEG) -> DialogSpec: - """Return the EEGLAB-like dialog spec for ``pop_clean_rawdata``.""" - chanlocs = _chanloc_records(EEG.get("chanlocs", [])) - labels = tuple(str(chan.get("labels", "")) for chan in chanlocs if isinstance(chan, dict)) - winsize = max(0.5, 1.5 * float(EEG.get("nbchan", 1)) / float(EEG.get("srate", 1))) - row4 = (0.1, 0.8, 0.2, 0.3) - row = (0.1, 1, 0.3) - row2 = (0.1, 1.2, 0.1) - return DialogSpec( - title="pop_clean_rawdata()", - function_name="pop_clean_rawdata", - eeglab_source="plugins/clean_rawdata/pop_clean_rawdata.m", - geometry=(1, row, 1, 1, row4, row4, row, row, row, 1, 1, row, row2, row2, 1, 1, row, row, 1, 1), - geomvert=(1, 1, 0.3, 1, 1, 1, 1, 1, 1, 0.3, 1, 1, 1, 1, 0.3, 1, 1, 1, 0.3, 1), - size=(681, 733), - help_text="pophelp('pop_clean_rawdata')", - controls=( - ControlSpec( - "checkbox", - "Remove channel drift (data not already high-pass filtered)", - tag="filter", - value=False, - font_weight="bold", - callback=CallbackSpec("toggle_enabled", params={"source": "filter", "targets": ("filterfreqs",)}), - ), - ControlSpec("spacer"), - ControlSpec("text", "Linear filter (FIR) transition band [lo hi] in Hz", enabled=False), - ControlSpec("edit", tag="filterfreqs", value="0.25 0.75", enabled=False), - ControlSpec("spacer"), - ControlSpec("checkbox", "Process/remove channels", tag="chanrm", value=True, font_weight="bold"), - ControlSpec("spacer"), - ControlSpec("checkbox", "Only consider these channels", tag="chanuseflag", value=False), - ControlSpec( - "pushbutton", - "...", - tag="chanuse_button", - enabled=bool(labels), - callback=CallbackSpec( - "select_channels", - params={"button": "chanuse_button", "target": "chanuse", "channels": labels}, - matlab_callback="pop_chansel(get(gcbf, 'userdata'), 'field', 'labels')", - ), - ), - ControlSpec("edit", tag="chanuse", value=""), - ControlSpec("spacer"), - ControlSpec("checkbox", "Ignore these channels (ECG, EMG, ...)", tag="chanignoreflag", value=False), - ControlSpec( - "pushbutton", - "...", - tag="chanignore_button", - enabled=bool(labels), - callback=CallbackSpec( - "select_channels", - params={"button": "chanignore_button", "target": "chanignore", "channels": labels}, - matlab_callback="pop_chansel(get(gcbf, 'userdata'), 'field', 'labels')", - ), - ), - ControlSpec("edit", tag="chanignore", value=""), - ControlSpec("spacer"), - ControlSpec("checkbox", "Remove channel if it is flat for more than (seconds)", tag="rmflat", value=True), - ControlSpec("edit", tag="rmflatsec", value="5"), - ControlSpec("spacer"), - ControlSpec("checkbox", "Max acceptable high-frequency noise std dev", tag="rmnoise", value=True), - ControlSpec("edit", tag="rmnoiseval", value="4"), - ControlSpec("spacer"), - ControlSpec("checkbox", "Min acceptable correlation with nearby chans [0-1]", tag="rmcorr", value=True), - ControlSpec("edit", tag="rmcorrval", value="0.8"), - ControlSpec("spacer"), - ControlSpec( - "checkbox", - "Perform Artifact Subspace Reconstruction bad burst correction/rejection", - tag="asr", - value=True, - font_weight="bold", - ), - ControlSpec("spacer"), - ControlSpec("text", f"Max acceptable {winsize:1.1f} second window std dev"), - ControlSpec("edit", tag="asrstdval", value="20"), - ControlSpec("spacer"), - ControlSpec( - "checkbox", "Use Riemanian distance metric (not Euclidean) - beta", tag="distance", value=False - ), - ControlSpec("spacer"), - ControlSpec("spacer"), - ControlSpec( - "checkbox", "Remove bad data periods (when uncheck, correct using ASR)", tag="asrrej", value=True - ), - ControlSpec("spacer"), - ControlSpec("spacer"), - ControlSpec( - "checkbox", - "Additional removal of bad data periods", - tag="rejwin", - value=True, - font_weight="bold", - ), - ControlSpec("spacer"), - ControlSpec("text", "Acceptable [min max] channel RMS range (+/- std dev)"), - ControlSpec("edit", tag="rejwinval1", value="-Inf 7"), - ControlSpec("spacer"), - ControlSpec("text", "Maximum out-of-bound channels (%)"), - ControlSpec("edit", tag="rejwinval2", value="25"), - ControlSpec("spacer"), - ControlSpec( - "checkbox", "Pop up scrolling data window with rejected data highlighted", tag="vis", value=True - ), - ), - ) - - -def _chanloc_records(chanlocs): - if chanlocs is None: - return [] - if isinstance(chanlocs, dict): - return [chanlocs] - if isinstance(chanlocs, np.ndarray): - return list(chanlocs.ravel()) - return list(chanlocs) - - -def _run_gui(EEG, renderer=None): - result = inputgui(pop_clean_rawdata_dialog_spec(EEG), renderer=renderer) - if result is None: - return None - options: dict[str, Any] = { - "FlatlineCriterion": "off", - "ChannelCriterion": "off", - "LineNoiseCriterion": "off", - "Highpass": "off", - "BurstCriterion": "off", - "WindowCriterion": "off", - "BurstRejection": False, - "Distance": "Euclidian", - } - if result.get("filter"): - options["Highpass"] = _parse_numeric_text(result.get("filterfreqs", "")) - if result.get("chanrm"): - if result.get("chanignoreflag"): - options["Channels_ignore"] = parse_text_tokens(result.get("chanignore", "")) - if result.get("chanuseflag"): - options["Channels"] = parse_text_tokens(result.get("chanuse", "")) - if result.get("rmflat"): - options["FlatlineCriterion"] = float(result.get("rmflatsec", 5)) - if result.get("rmcorr"): - options["ChannelCriterion"] = float(result.get("rmcorrval", 0.8)) - if result.get("rmnoise"): - options["LineNoiseCriterion"] = float(result.get("rmnoiseval", 4)) - if result.get("asr"): - options["BurstCriterion"] = float(result.get("asrstdval", 20)) - if result.get("distance"): - options["Distance"] = "Riemannian" - if result.get("rejwin"): - options["WindowCriterionTolerances"] = _parse_numeric_text(result.get("rejwinval1", "")) - options["WindowCriterion"] = float(result.get("rejwinval2", 25)) / 100.0 - if result.get("asrrej") and options["BurstCriterion"] != "off": - options["BurstRejection"] = True - options[_SHOW_VIS_ARTIFACTS_KEY] = bool(result.get("vis")) - return options - - -def _normalise_options(options): - normalised = {} - for key, value in options.items(): - canonical = _OPTION_ALIASES.get(str(key).lower(), key) - if canonical == "BurstRejection": - value = _as_bool(value) - normalised[canonical] = value - return normalised - - -def _as_bool(value): - if isinstance(value, str): - return value.lower() == "on" - return bool(value) - - -def _parse_numeric_text(text): - values = [] - for value in re.split(r"[\s,]+", str(text).strip().strip("[]")): - if not value: - continue - if value.lower() == "-inf": - values.append(-np.inf) - elif value.lower() == "inf": - values.append(np.inf) - else: - values.append(float(value)) - return values - - -def _history_command(options): - if not options: - return "EEG = pop_clean_rawdata(EEG);" - parts = [] - for key, value in options.items(): - parts.extend([_clean_rawdata_history_value(key), _clean_rawdata_history_value(value)]) - return f"EEG = pop_clean_rawdata(EEG, {', '.join(parts)});" - - -def _clean_rawdata_history_value(value): - return format_history_value(value, bool_style="onoff", empty_sequence="{}") diff --git a/src/eegprep/plugins/clean_rawdata/private/__init__.py b/src/eegprep/plugins/clean_rawdata/private/__init__.py deleted file mode 100644 index c73917bd..00000000 --- a/src/eegprep/plugins/clean_rawdata/private/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Private clean_rawdata helper ports.""" diff --git a/src/eegprep/plugins/clean_rawdata/private/channel_removal.py b/src/eegprep/plugins/clean_rawdata/private/channel_removal.py deleted file mode 100644 index eceb7717..00000000 --- a/src/eegprep/plugins/clean_rawdata/private/channel_removal.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Shared clean_rawdata channel-removal helpers.""" - -from __future__ import annotations - -from typing import Any - -import numpy as np - -CHANNEL_DEPENDENT_FIELDS = ("icawinv", "icasphere", "icaweights", "icaact", "stats", "specdata", "specicaact") - - -def remove_channels_without_pop_select(EEG: dict[str, Any], removed_channels: np.ndarray) -> dict[str, Any]: - """Remove channels when ``pop_select`` is unavailable.""" - removed = np.asarray(removed_channels, dtype=bool).ravel() - data = np.asarray(EEG["data"], dtype=np.float32) - chanlocs = EEG.get("chanlocs", []) - if len(chanlocs) == data.shape[0]: - if isinstance(chanlocs, np.ndarray): - EEG["chanlocs"] = chanlocs[~removed] - else: - EEG["chanlocs"] = np.asarray([chanloc for index, chanloc in enumerate(chanlocs) if not removed[index]]) - EEG["data"] = data[~removed, ...] - EEG["nbchan"] = EEG["data"].shape[0] - for field in CHANNEL_DEPENDENT_FIELDS: - if field in EEG: - EEG[field] = np.array([]) - return EEG - - -def update_clean_channel_mask(EEG: dict[str, Any], removed_channels: np.ndarray) -> None: - """Update ``EEG.etc.clean_channel_mask`` after current-channel removal.""" - removed = np.asarray(removed_channels, dtype=bool).ravel() - etc = EEG.setdefault("etc", {}) - mask = etc.get("clean_channel_mask") - if mask is not None: - existing = np.asarray(mask, dtype=bool).ravel() - if int(np.sum(existing)) == removed.size: - updated = np.array(existing, dtype=bool, copy=True) - updated[updated] = ~removed - etc["clean_channel_mask"] = updated - return - etc["clean_channel_mask"] = ~removed diff --git a/src/eegprep/plugins/clean_rawdata/private/covariance.py b/src/eegprep/plugins/clean_rawdata/private/covariance.py deleted file mode 100644 index cd646640..00000000 --- a/src/eegprep/plugins/clean_rawdata/private/covariance.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Tools for working with covariance matrices or stacks thereof.""" - -# Copyright (c) 2015-2025 Syntrogi Inc. dba Intheon. - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import logging - -import numpy as np - -from eegprep.functions.miscfunc.misc import finite_matmul - -logger = logging.getLogger(__name__) - -__all__ = ['cov_mean', 'cov_logm', 'cov_expm', 'cov_powm', 'cov_sqrtm', 'cov_rsqrtm', 'cov_sqrtm2', 'cov_shrinkage'] - - -def diag_nd(M): - """Like np.diag, but in case of a ...,N, returns a ...,N,N array of diag matrices.""" - *dims, N = M.shape - if dims: - cat = np.concatenate([np.diag(d) for d in M.reshape((-1, N))]) - return np.reshape(cat, dims + [N, N]) - else: - return np.diag(M) - - -def cov_logm(C): - """Calculate the matrix logarithm of a covariance matrix or ...,N,N array.""" - D, V = np.linalg.eigh(C) - return finite_matmul(finite_matmul(V, diag_nd(np.log(D))), V.swapaxes(-2, -1)) - - -def cov_expm(C): - """Calculate the matrix exponent of a covariance matrix or ...,N,N array.""" - D, V = np.linalg.eigh(C) - return finite_matmul(finite_matmul(V, diag_nd(np.exp(D))), V.swapaxes(-2, -1)) - - -def cov_powm(C, exp): - """Calculate a matrix power of a covariance matrix or ...,N,N array.""" - D, V = np.linalg.eigh(C) - return finite_matmul(finite_matmul(V, diag_nd(D**exp)), V.swapaxes(-2, -1)) - - -def cov_sqrtm(C): - """Calculate the matrix square root of a covariance matrix or ...,N,N array.""" - D, V = np.linalg.eigh(C) - return finite_matmul(finite_matmul(V, diag_nd(np.sqrt(D))), V.swapaxes(-2, -1)) - - -def cov_rsqrtm(C): - """Calculate the matrix reciprocal square root of a covariance matrix or ...,N,N array.""" - D, V = np.linalg.eigh(C) - return finite_matmul(finite_matmul(V, diag_nd(1.0 / np.sqrt(D))), V.swapaxes(-2, -1)) - - -def cov_sqrtm2(C): - """Calculate the matrix square root, and its reciprocal, for a covariance matrix or ...,N,N array.""" - D, V = np.linalg.eigh(C) - sqrtD = np.sqrt(D) - return ( - finite_matmul(finite_matmul(V, diag_nd(sqrtD)), V.swapaxes(-2, -1)), - finite_matmul(finite_matmul(V, diag_nd(1.0 / sqrtD)), V.swapaxes(-2, -1)), - ) - - -def cov_mean(X, *, weights=None, robust=False, iters=50, tol=1e-5, huber=0, nancheck=False, verbose=False): - """Calculate the (weighted) average of a set of covariance matrices on the manifold of SPD matrices, optionally robustly using the geometric median or Huber mean. - - Args: - X: a M,N,N array of covariance matrices - weights: optionally a vector of sample weights (can be unnormalized) - robust: whether to use a robust estimator - iters: maximum number of iterations - huber: huber threshold (delta parameter); can be set to - * None: use regular least-squares solution - * 0: use geometric / l1 median - * >0: use a Huber mean with the given value as the threshold - tol: tolerance for convergence check - nancheck: check for NaNs - verbose: generate verbose output (will print deviations in huber=None mode) - - Returns - ------- - the N,N mean covariance matrix - """ - # This algorithm is based on: - # [1] Ostresh et al., 1978, "On the Convergence of a Class of Iterative Methods for Solving the Weber Location Problem" - # [2] Fletcher et al., 2004, "Principal Geodesic Analysis on Symmetric Spaces: Statistics of Diffusion Tensors" - # [3] Fletcher et al. 2010, "The geometric median on Riemannian manifolds with application to robust atlas estimation" - # [4] Barachant et al., 2014, "Multiclass Brain-Computer Interface Classification by Riemannian Geometry" - weights = np.ones(len(X)) if weights is None else np.asarray(weights) - scales = weights - - mu = np.sum(X * weights[:, None, None], axis=0) / np.sum(weights) - # step size and divergence check threshold - step, thresh = 1.0, 1e20 - for i in range(iters): - mu_sqrt, mu_rsqrt = cov_sqrtm2(mu) - # linearize around mu (this would be the tangent space, but we omit - # the pre/post-multiplied mu_sqrt terms since they cancel in both - # the scale calculation and the exponential map) - Xt = cov_logm(finite_matmul(finite_matmul(mu_rsqrt, X), mu_rsqrt)) - # geometric-median correction (downweight each pt by its riemannian - # distance from mu, which we calc here after linearization) - if robust: - # deviations/errors per sample - d = np.sqrt(np.sum(np.square(Xt), axis=(-2, -1))) - # apply robust scale factor to provided sample weights - if huber is None: - scales = weights - if verbose: - logger.info(f"median deviations: {np.median(d)}") - elif huber == 0: - scales = weights / d - else: - w = np.where(d <= huber, 1, huber / d) - scales = weights * w - # get update Jacobian (np.average takes care of renormalization) - J = np.sum(Xt * scales[:, None, None], axis=0) / np.sum(scales) - # apply update on manifold - mu = finite_matmul(finite_matmul(mu_sqrt, cov_expm(step * J)), mu_sqrt) - # convergence checks - Jnorm = np.sqrt(np.sum(np.square(J))) - if Jnorm < tol or step < tol: - break - h = step * Jnorm - if h < thresh: - # exponentially decaying learning rate - step *= 0.95 - thresh = h - else: - # prevent blow-up - step /= 2 - if nancheck and np.any(np.isnan(mu)): - raise RuntimeError("NaNs occurred in cov_mean()") - return mu - - -def cov_shrinkage(cov, shrinkage=0, *, target='eye'): - """Regularize the given covariance matrix or stack of matrices using shrinkage. - - Args: - cov: the covariance matrix (N,N) or stack of matrices (...,N,N). - shrinkage: degree of shrinkage, between 0 and 1 - target: target matrix to shrink towards; can be: - 'eye': the identity matrix (classic shrinkage; good for small values - of shrinkage) - 'scaled-eye': the identity matrix, scaled to the average variance - of the data (can be practical when shrinkage degree is large, since - otherwise whitening will not have unit variance) - 'diag': the diagonal of the covariance matrix (diagonal shrinkage) - - Returns - ------- - the regularized covariance matrix or stack of matrices. - """ - if not shrinkage: - return cov # early exit - - N = cov.shape[-1] - - if target == 'eye': - # create a stack of identity matrices matching cov's shape - eye_target = np.zeros_like(cov) - eye_target[..., range(N), range(N)] = 1 - elif target == 'scaled-eye': - # calculate trace for each matrix in the stack (or single matrix) - # trace_cov will have shape cov.shape[:-2] or be scalar if cov is 2D - trace_cov = np.trace(cov, axis1=-2, axis2=-1) - scale = trace_cov / N - - # create a base stack of identity matrices - eye_base = np.zeros_like(cov) - eye_base[..., range(N), range(N)] = 1 - - # apply scaling - scale_val = scale - if cov.ndim > 2: - scale_val = scale[..., np.newaxis, np.newaxis] - eye_target = eye_base * scale_val - elif target == 'diag': - # get the main diagonal of each matrix in the stack - main_diagonals = np.diagonal(cov, axis1=-2, axis2=-1) - # create a stack of diagonal matrices - eye_target = diag_nd(main_diagonals) - else: - raise ValueError(f'Unsupported shrinkage target: {target}') - - cov_regu = shrinkage * eye_target + (1 - shrinkage) * cov - return cov_regu diff --git a/src/eegprep/plugins/clean_rawdata/private/masks.py b/src/eegprep/plugins/clean_rawdata/private/masks.py deleted file mode 100644 index b58d08ce..00000000 --- a/src/eegprep/plugins/clean_rawdata/private/masks.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Boolean-mask helpers for clean_rawdata ports.""" - -import numpy as np - - -def mask_to_intervals(mask: np.ndarray, *, value: bool) -> np.ndarray: - """Convert a boolean sample mask to EEGLAB-style sample intervals.""" - target = np.asarray(mask, dtype=bool) == value - if not np.any(target): - return np.empty((0, 2), dtype=int) - padded = np.concatenate([[False], target, [False]]) - diff = np.diff(padded.astype(int)) - starts = np.where(diff == 1)[0] + 1 - ends = np.where(diff == -1)[0] - return np.stack([starts, ends], axis=1).astype(int) diff --git a/src/eegprep/plugins/clean_rawdata/private/ransac.py b/src/eegprep/plugins/clean_rawdata/private/ransac.py deleted file mode 100644 index 1a80bee1..00000000 --- a/src/eegprep/plugins/clean_rawdata/private/ransac.py +++ /dev/null @@ -1,152 +0,0 @@ -"""RANSAC utilities for EEG data processing.""" - -from typing import Optional - -import numpy as np - -from ....functions.adminfunc.eeglabcompat import get_eeglab -from ....functions.miscfunc.misc import round_mat -from .sphericalSplineInterpolate import sphericalSplineInterpolate - - -def rand_sample(n: int, m: int, stream: np.random.RandomState) -> np.ndarray: - """Random sampling without replacement using Fisher-Yates shuffle. - - Optimized O(n) implementation using swap-based Fisher-Yates instead of - the previous O(n²) delete-based approach. Returns first m elements of - a random permutation of n items. - - Args: - n: number of items to sample from - m: number of items to sample - stream: random number generator - - Returns: - random_sample: array of m sampled values (indices 0..n-1) - - Performance: - O(n) time complexity (was O(n²) in previous implementation) - For n=1M: ~3s (was ~80s) - 25x faster - - Note: - This implementation uses Fisher-Yates shuffle for efficiency. - Results differ from the old O(n²) delete-based implementation, - but maintain parity with MATLAB's optimized rand_sample. - """ - # Start with identity permutation - pool = np.arange(n) - - # Fisher-Yates shuffle: only shuffle first m elements - for k in range(m): - # Choose from remaining elements (k to n-1) - remaining = n - k - choice = int(round_mat((remaining - 1) * stream.rand())) - - # Swap pool[k] with pool[k + choice] - idx = k + choice - pool[k], pool[idx] = pool[idx], pool[k] - - # Return first m elements - return pool[:m].copy() - - -def rand_permutation(n: int, stream: np.random.RandomState) -> np.ndarray: - """Random permutation with MATLAB parity using Fisher-Yates shuffle. - - This function produces the SAME permutation sequence as MATLAB's - rand_permutation() when both use the same RNG seed (5489). It achieves - parity by using rand() + round_mat() in a Fisher-Yates shuffle pattern - that matches MATLAB's implementation. - - Optimized O(n) implementation (was O(n²) in previous version). - - Args: - n: number of items to permute (returns permutation of 0..n-1) - stream: random number generator (np.random.RandomState) - - Returns: - permutation: array of indices 0..n-1 in random order - - Performance: - O(n) time complexity (was O(n²)) - For n=1M: ~3s (was ~80s) - 25x faster - - Example: - >>> rng = np.random.RandomState(5489) - >>> perm = rand_permutation(10, rng) - >>> # Matches MATLAB: rng(5489,'twister'); rand_permutation(10) - 1 - - Note: - This function is critical for ICA parity between Python and MATLAB. - Uses Fisher-Yates shuffle for O(n) performance. - Results differ from old O(n²) implementation but maintain - cross-platform parity with MATLAB. - See test_parity_rng.py for verification tests. - """ - # Start with identity permutation [0, 1, 2, ..., n-1] - result = np.arange(n) - - # Fisher-Yates shuffle: iterate backward from n-1 to 1 - for k in range(n - 1, 0, -1): - # Pick random index from 0 to k (inclusive) - j = int(round_mat(k * stream.rand())) - - # Swap elements k and j - result[k], result[j] = result[j], result[k] - - return result - - -def calc_projector( - locs: np.ndarray, - num_samples: int, - subset_size: int, - stream: Optional[np.random.RandomState] = None, - subroutine: str = 'sphericalSplineInterpolate', -) -> np.ndarray: - """Calculate a bag of reconstruction matrices from random channel subsets. - - Args: - locs: Nx3 array of channel locations - num_samples: number of random samples to generate - subset_size: size of each random subset - stream: optionally the random number generator to use; - if not specified, will default to a fixed seed (435656) - subroutine: which interpolation subroutine to use (for testing) - - Returns - ------- - P : combined projector matrix - """ - if stream is None: - stream = np.random.RandomState(435656) - - # noinspection PyUnresolvedReferences - rand_samples = np.zeros((locs.shape[0], num_samples, locs.shape[0])) - - if subroutine == 'sphericalSplineInterpolate': - - def op(src, dest): - return sphericalSplineInterpolate(src.T, dest.T)[0] - - elif subroutine == 'matlab': - matlab = get_eeglab('MAT') - - def op(src, dest): - return matlab.sphericalSplineInterpolate(src.T, dest.T)[0] - - elif subroutine == 'octave': - octave = get_eeglab('OCT') - - def op(src, dest): - return octave.sphericalSplineInterpolate(src.T, dest.T)[0] - - else: - raise ValueError(f'Unknown subroutine: {subroutine}') - - # noinspection PyShadowingNames - for k in range(num_samples - 1, -1, -1): - sample = rand_sample(locs.shape[0], subset_size, stream) - tmp = op(locs[sample, :], locs) - rand_samples[sample, k, :] = np.real(tmp).T - return np.reshape(rand_samples, (locs.shape[0], -1)) diff --git a/src/eegprep/plugins/clean_rawdata/private/sigproc.py b/src/eegprep/plugins/clean_rawdata/private/sigproc.py deleted file mode 100644 index 3102f70b..00000000 --- a/src/eegprep/plugins/clean_rawdata/private/sigproc.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Signal processing utilities.""" - -from typing import Union - -import numpy as np -from scipy.signal import fftconvolve - -__all__ = ['filtfilt_fast', 'moving_average'] - - -def filtfilt_fast( - b: np.ndarray, - a: Union[float, np.ndarray], - x: np.ndarray, -) -> np.ndarray: - """Apply a zero-phase forward-backward filter to a signal using FFTs. - - This is a drop-in replacement for scipy.signal.filtfilt() that is considerably faster - for long signals. - - Parameters - ---------- - b : np.ndarray - Numerator coefficients of the filter. - a : float or np.ndarray - Must be 1. - x : np.ndarray - Signal to filter (1-D array). - - Returns - ------- - np.ndarray - The filtered signal. - """ - assert a == 1, "a must be 1; use filtfilt() for IIR filters" - n = len(b) - # pad the signal at both ends - x_padded = np.pad(x, (n, n), mode='reflect', reflect_type='odd') - # filter, reverse - y_forward = fftconvolve(x_padded, b, mode='full')[::-1] - # filter, reverse - y_filtered = fftconvolve(y_forward, b, mode='full')[::-1] - # trim off padding - excess = len(y_filtered) - len(x) - y_depadded = y_filtered[excess // 2 : -excess // 2] - return y_depadded - - -def moving_average(X, *, N=3, axis=-1, Z=None, inplace=False, transform=None, init=None): - """Lfilter()-style moving average function with support for state. - - Parameters - ---------- - X : array_like - Signal to filter. - N : int, optional - Number of points that shall be averaged (window length). - axis : int, optional - Axis along which to filter; note: IF you use transform, and if - it inserts additional axes, the same index needs to work before and - after the transform (e.g., you can use negative indices to count from - the end if needed to accomplish that). - Z : object, optional - Initial state (or None). - inplace : bool, optional - Whether to overwrite the input. - transform : callable, optional - Optionally a transformation to apply to each input sample, - usually to generate higher-dimensional data; one use case is to calculate - covariance matrices per sample on the fly instead of having the moving average - to apply to and buffer potentially very large covariance data - (by passing lambda x: x[:, None] @ x[None, :]). - init : int or None, optional - How to behave on the first N samples of input; if set to 0, - this will behave as if the data were pre-pended by zeros; if set to None, - this will average the (fewer, noisier) samples in the buffer. - - Returns - ------- - X' : array_like - Filtered signal. - Z' : object - Final state (can be passed into the next call to moving_average()). - - License - ------- - Copyright (c) 2015-2025 Syntrogi Inc. dba Intheon. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - """ - - class MovAvgState: - """State representation for moving_average() filter function.""" - - def __init__(self, p, buf, acc, n): - self.p, self.buf, self.acc, self.n = p, buf, acc, n - - if transform and inplace: - raise ValueError("You cannot use inplace and transform at the same time.") - if transform is None: - - def transform(x): - return x - # we're doing some extra homework here to be able to buffer and transform the data - # without swizzling axes (which creates temporaries that can exceed the memory), - # so we have to be able to do all operations on input and output along the desired axis, - # which may also count from the end - - def slice_at(x, k): - """Generate an index slice that will slice x at the desired axis.""" - slices = [slice(None)] * x.ndim - slices[axis] = k - return tuple(slices) - - if not inplace: - # Complicated expression to generate a new shape after transform with the - # right shape at axis - Yshp = list(np.stack([transform(X[slice_at(X, 0)])], axis=axis).shape) - Yshp[axis] = X.shape[axis] - Y = np.zeros_like(X, shape=Yshp) - else: - Y = None - if not Z: - if init is None: - init_n = 0 - elif init == 0: - init_n = N - else: - raise ValueError("init must be 0 or None") - Z = MovAvgState( - p=0, buf=np.zeros_like(X[slice_at(X, [0] * N)]), acc=np.zeros_like(transform(X[slice_at(X, 0)])), n=init_n - ) - - for k in range(X.shape[axis]): - # this is basically the buffered moving average trick (updating/downdating - # the covariance matrix with each added/removed sample), but additionally - # we're allowing the samples to be transformed to e.g. higher dimensions - # to reduce buffer space, which can be very large for long moving averages - e = X[slice_at(X, k)] - Z.n += 1 - Z.acc += transform(e) - transform(Z.buf[slice_at(Z.buf, Z.p)]) - Z.buf[slice_at(Z.buf, Z.p)] = e - res = Z.acc / min(N, Z.n) - if inplace: - X[slice_at(X, k)] = res - else: - Y[slice_at(Y, k)] = res - Z.p = (Z.p + 1) % N - return (X if inplace else Y), Z diff --git a/src/eegprep/plugins/clean_rawdata/private/sphericalSplineInterpolate.py b/src/eegprep/plugins/clean_rawdata/private/sphericalSplineInterpolate.py deleted file mode 100644 index 8496b52f..00000000 --- a/src/eegprep/plugins/clean_rawdata/private/sphericalSplineInterpolate.py +++ /dev/null @@ -1,199 +0,0 @@ -"""Spatial interpolation utilities.""" - -import numpy as np - -from eegprep.functions.miscfunc.misc import finite_matmul, finite_pinv - - -# Helper function (vectorized version of MATLAB's interpMx) -def _interpMx(cosEE, order, tol): - """Compute the interpolation matrix for a set of point pairs (vectorized). - - Internal helper function for sphericalSplineInterpolate. - - Args: - cosEE (np.ndarray): Matrix of cosines of angles between points. - order (int): Order of the polynomial interpolation. - tol (float): Tolerance for the Legendre polynomial approximation convergence. - - Returns - ------- - tuple[np.ndarray, np.ndarray]: G and H matrices. - """ - x = np.asarray(cosEE) # Ensure input is a numpy array - - # Initialize variables for Legendre polynomial recurrence (vectorized) - # Using float for n even for integers to ensure float division later - n = 1.0 - Pns1 = np.ones_like(x) - Pn = x.copy() # Use a copy to avoid modifying input if it was passed by reference - - # Calculate initial terms for G and H sums - nn_plus_n = n * n + n # = 2.0 when n=1 - # Ensure float exponentiation/division - tmp = ((2.0 * n + 1.0) * Pn) / (nn_plus_n ** float(order)) - G = tmp.copy() # Start sum for G - H = nn_plus_n * tmp # Start sum for H - - # Initialize convergence tracking variables - # Initialize dG/dH with the magnitude of the first term; avoids issues if G/H start near zero - dG = np.abs(G) - dH = np.abs(H) - - # Summation loop for Legendre polynomial series (vectorized) - # Max iterations set to 500 as in the MATLAB code - for n_int in range(2, 501): - n = float(n_int) # Use float n for calculations - - # Legendre polynomial recurrence relation (vectorized) - Pns2 = Pns1 - Pns1 = Pn - Pn = ((2.0 * n - 1.0) * x * Pns1 - (n - 1.0) * Pns2) / n - - # Store old G, H for convergence check (make copies) - oG = G.copy() - oH = H.copy() - - # Calculate update term 'tmp' (vectorized) - nn_plus_n = n * n + n - # Ensure float exponentiation/division - tmp = ((2.0 * n + 1.0) * Pn) / (nn_plus_n ** float(order)) - - # Update G and H sums (vectorized) - G += tmp # update function estimate, spline interp - H += nn_plus_n * tmp # update function estimate, SLAP - - # Update moving average gradient estimate for convergence (vectorized) - # Add small epsilon to denominator to prevent potential division by zero if dG/dH were zero? - # Although, initialization above should prevent this. Let's stick to MATLAB logic. - dG = (np.abs(oG - G) + dG) / 2.0 - dH = (np.abs(oH - H) + dH) / 2.0 - - # Check for convergence (break if *all* elements meet tolerance) - # Using np.all mimics the intent that the sum converges everywhere - if np.all(dG < tol) and np.all(dH < tol): - break - - # Final scaling - G /= 4.0 * np.pi - H /= 4.0 * np.pi - - return G, H - - -# Main function mirroring the MATLAB sphericalSplineInterpolate -def sphericalSplineInterpolate(src, dest, lambda_reg=1e-5, order=4, type='spline', tol=np.finfo(float).eps): - """Interpolation matrix for spherical interpolation. Python port of Jason Farquhar's MATLAB code. - - Args: - src (np.ndarray): Source electrode positions [3 x N]. Assumes coordinates are in columns. - dest (np.ndarray): Destination electrode positions [3 x M]. Assumes coordinates are in columns. - lambda_reg (float, optional): Regularisation parameter for smoothing estimates. Defaults to 1e-5. - (Renamed from 'lambda' to avoid clash with Python keyword). - order (int, optional): Order of the polynomial interpolation to use. Defaults to 4. - type (str, optional): Interpolation type, one of 'spline' or 'slap'. Defaults to 'spline'. - 'spline' -> spherical Spline - 'slap' -> surface Laplician (aka CSD) - tol (float, optional): Tolerance for the Legendre polynomial approximation convergence. - Defaults to machine epsilon for float. - - Returns - ------- - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - W: [M x N] linear mapping matrix between old and new coords. - Gss: [N x N] interpolation matrix between source points. - Gds: [M x N] interpolation matrix from source to destination points. - Hds: [M x N] SLAP interpolation matrix from source to destination points. - - Notes - ----- - Based upon the paper: Perrin, F., Pernier, J., Bertrand, O., & Echallier, J. F. (1989). - Spherical splines for scalp potential and current density mapping. - Electroencephalography and clinical neurophysiology, 72(2), 184-187. - - Original MATLAB Copyright Notice: - Copyright 2009- by Jason D.R. Farquhar (jdrf@zepler.org) - Permission is granted for anyone to copy, use, or modify this - software and accompanying documents, provided this copyright - notice is retained, and note is made of any changes that have been - made. This software and documents are distributed without any - warranty, express or implied. - """ - # Ensure inputs are numpy arrays - src = np.asarray(src) - dest = np.asarray(dest) - - # Validate input shapes (optional but good practice) - if src.ndim != 2 or src.shape[0] != 3: - raise ValueError(f"src must be a 2D array with shape (3, N), got {src.shape}") - if dest.ndim != 2 or dest.shape[0] != 3: - raise ValueError(f"dest must be a 2D array with shape (3, M), got {dest.shape}") - - n_src = src.shape[1] - n_dest = dest.shape[1] - - # Map the positions onto the unit sphere - # Normalize each column vector - norm_src = np.sqrt(np.sum(src**2, axis=0, keepdims=True)) - # Avoid division by zero if a source position is exactly at the origin - norm_src[norm_src == 0] = 1.0 - src_norm = src / norm_src - - norm_dest = np.sqrt(np.sum(dest**2, axis=0, keepdims=True)) - # Avoid division by zero - norm_dest[norm_dest == 0] = 1.0 - dest_norm = dest / norm_dest - - # Calculate the cosine of the angle between the new and old electrodes. - # If the vectors are on top of each other, the result is 1. - # Transpose src_norm (N, 3) and dest_norm (M, 3) for matrix multiplication - cosSS = finite_matmul(src_norm.T, src_norm) # angles between source positions [N x N] - cosDS = finite_matmul(dest_norm.T, src_norm) # angles between destination positions [M x N] - - # Ensure cosines are within [-1, 1] due to potential floating point errors - cosSS = np.clip(cosSS, -1.0, 1.0) - cosDS = np.clip(cosDS, -1.0, 1.0) - - # Compute the interpolation matrices G and H using the helper function - # Pass the tolerance 'tol' from the main function call - Gss, _ = _interpMx(cosSS, order, tol) # [N x N] (Hss not needed) - Gds, Hds = _interpMx(cosDS, order, tol) # [M x N] - - # Include the regularisation - if lambda_reg > 0: - Gss = Gss + lambda_reg * np.eye(n_src) - - # Compute the mapping to the polynomial coefficients space - # N.B. this can be numerically unstable so use the PINV to solve.. - muGss = 1.0 # Fixed value as in the MATLAB code (comment mentioned median(diag(Gss))) - - # Construct matrix C - # C = [ Gss muGss*ones(N,1) ] - # [ muGss*ones(1,N) 0 ] - C = np.zeros((n_src + 1, n_src + 1)) - C[:n_src, :n_src] = Gss - C[:n_src, n_src] = muGss # Column of ones * muGss - C[n_src, :n_src] = muGss # Row of ones * muGss - # C[n_src, n_src] remains 0 - - # Calculate the pseudoinverse of C - iC = finite_pinv(C) # [N+1 x N+1] - - # Compute the final mapping matrix W based on the specified type - type_lower = type.lower() - if type_lower == 'spline': - # W = [Gds ones(M,1)*muGss] * iC[:, :-1] - # Construct the [Gds ones(M,1)*muGss] matrix part - Gds_augmented = np.hstack((Gds, muGss * np.ones((n_dest, 1)))) # [M x N+1] - # Multiply by the relevant part of iC - W = finite_matmul(Gds_augmented, iC[:, :n_src]) # [M x N+1] @ [N+1 x N] = [M x N] - - elif type_lower == 'slap': - # W = Hds * iC[:-1, :-1] - W = finite_matmul(Hds, iC[:n_src, :n_src]) # [M x N] @ [N x N] = [M x N] - - else: - raise ValueError(f"Unknown interpolation type specified: '{type}'. Must be 'spline' or 'slap'.") - - # Return the mapping matrix W and intermediate G/H matrices as per MATLAB signature - return W, Gss, Gds, Hds diff --git a/src/eegprep/plugins/clean_rawdata/private/stats.py b/src/eegprep/plugins/clean_rawdata/private/stats.py deleted file mode 100644 index 709a1855..00000000 --- a/src/eegprep/plugins/clean_rawdata/private/stats.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Statistical utilities for EEG data.""" - -import math -import logging -import numpy as np -from numpy.linalg import norm as np_norm # Use alias to avoid potential name collision -from scipy.special import gamma, gammaincinv -from ....functions.miscfunc.misc import round_mat - -logger = logging.getLogger(__name__) - - -def fit_eeg_distribution( - X, min_clean_fraction=None, max_dropout_fraction=None, quants=None, step_sizes=None, beta=None -): - """Estimate the mean and standard deviation of clean EEG from contaminated data. - - Mu,Sigma,Alpha,Beta = fit_eeg_distribution(X,MinCleanFraction,MaxDropoutFraction,FitQuantiles,StepSizes,ShapeRange) - - This function estimates the mean and standard deviation of clean EEG from a - sample of amplitude values (that have preferably been computed over short - windows) that may include a large fraction of contaminated samples. The - clean EEG is assumed to represent a generalized Gaussian component in a - mixture with near-arbitrary artifact components. By default, at least 25% - (MinCleanFraction) of the data must be clean EEG, and the rest can be - contaminated. No more than 10% (MaxDropoutFraction) of the data is - allowed to come from contaminations that cause lower-than-EEG amplitudes - (e.g., sensor unplugged). There are no restrictions on artifacts causing - larger-than-EEG amplitudes, i.e., virtually anything is handled (with the - exception of a very unlikely type of distribution that combines with the - clean EEG samples into a larger symmetric generalized Gaussian peak and - thereby "fools" the estimator). The default parameters should be fine for - a wide range of settings but may be adapted to accomodate special - circumstances. - - The method works by fitting a truncated generalized Gaussian whose - parameters are constrained by MinCleanFraction, MaxDropoutFraction, - FitQuantiles, and ShapeRange. The alpha and beta parameters of the gen. - Gaussian are also returned. The fit is performed by a grid search that - always finds a close-to-optimal solution if the above assumptions are - fulfilled. - - Args: - X : array-like - Vector of amplitude values of EEG, possible containing artifacts - (coming from single samples or windowed averages). - min_clean_fraction : float, optional - Minimum fraction of values in X that needs to be clean - (default: 0.25). - max_dropout_fraction : float, optional - Maximum fraction of values in X that can be subject to - signal dropouts (e.g., sensor unplugged) (default: 0.1). - quants : tuple or list, optional - Quantile range [lower,upper] of the truncated generalized Gaussian - distribution that shall be fit to the EEG contents - (default: (0.022, 0.6)). - step_sizes : tuple or list, optional - Step size of the grid search; the first value is the stepping of the - lower bound (which essentially steps over any dropout samples), and - the second value is the stepping over possible scales (i.e., - clean-data quantiles) (default: (0.01, 0.01)). - beta : array-like, optional - Range that the clean EEG distribution's shape parameter beta may take - (default: np.arange(1.7, 3.6, 0.15)). - - Returns - ------- - tuple: - - mu (float): estimated mean of the clean EEG distribution. - - sig (float): estimated standard deviation of the clean EEG distribution. - - alpha (float): estimated scale parameter of the generalized Gaussian - clean EEG distribution. - - beta (float): estimated shape parameter of the generalized Gaussian - clean EEG distribution. - """ - # --- Assign defaults --- - if min_clean_fraction is None: - min_clean_fraction = 0.25 - if max_dropout_fraction is None: - max_dropout_fraction = 0.1 - if quants is None: - # Use tuple for immutability, common practice in Python for fixed sequences - quants = (0.022, 0.6) - if step_sizes is None: - step_sizes = (0.01, 0.01) - if beta is None: - beta = np.arange(1.7, 3.6, 0.15) - - # Convert potentially list inputs to numpy arrays for consistency - quants = np.array(quants) - step_sizes = np.array(step_sizes) - beta = np.array(beta) - - # --- Sanity checks --- - # Use isinstance for type checking, len() for number of elements - if not isinstance(quants, np.ndarray) or quants.ndim != 1 or len(quants) != 2: - raise ValueError('Fit quantiles needs to be a 2-element vector.') - if np.any(quants < 0) or np.any(quants > 1): - raise ValueError('Unreasonable fit quantiles.') - if np.any(step_sizes < 0.0001) or np.any(step_sizes > 0.1): - raise ValueError('Unreasonable step sizes.') - # Allow slightly wider range check than MATLAB's >=7, <=1 - if np.any(beta > 7) or np.any(beta < 1): - raise ValueError('Unreasonable shape range.') - - # --- Sort data so we can access quantiles directly --- - # Ensure X is a numpy array, convert to float (like MATLAB's double), flatten - X = np.asarray(X, dtype=float).flatten() - X.sort() - n = len(X) - - # --- Calc z bounds for the truncated standard generalized Gaussian pdf and pdf rescaler --- - zbounds = [] # Use a list to store bounds for each beta - rescale = np.zeros_like(beta, dtype=float) # Pre-allocate rescale array - for i, b_val in enumerate(beta): - # Calculate bounds using gammaincinv - # Note: MATLAB's gammaincinv(A,X) finds y where gammainc(y,A,'lower') = X. - # scipy.special.gammaincinv(a, y) finds x where gammainc(a, x) = y. - # The argument sign(q-0.5)*(2*q-1) simplifies to abs(2*q-1). - # We need y such that P(1/b, y) = abs(2*q-1), so y = gammaincinv(1/b, abs(2*q-1)). - # The final z is sign(q-0.5) * y**(1/b). - lower_bound_arg = abs(2 * quants[0] - 1) - upper_bound_arg = abs(2 * quants[1] - 1) - - lower_y = gammaincinv(1.0 / b_val, lower_bound_arg) - upper_y = gammaincinv(1.0 / b_val, upper_bound_arg) - - lower_z = np.sign(quants[0] - 0.5) * np.power(lower_y, 1.0 / b_val) - upper_z = np.sign(quants[1] - 0.5) * np.power(upper_y, 1.0 / b_val) - - zbounds.append(np.array([lower_z, upper_z])) - rescale[i] = b_val / (2.0 * gamma(1.0 / b_val)) - - # --- Determine the quantile-dependent limits for the grid search --- - lower_min = np.min(quants) # we can generally skip the tail below the lower quantile - max_width = np.diff(quants)[0] # maximum width is the fit interval if all data is clean - min_width = min_clean_fraction * max_width # minimum width of the fit interval, as fraction of data - - # --- Get matrix of shifted data ranges --- - # Generate start indices based on lower quantile, dropout fraction, and step size - # Use np.arange; add a small fraction of step_sizes[0] to ensure the endpoint is included if it's a multiple of the step - start_indices = round_mat( - n * np.arange(lower_min, lower_min + max_dropout_fraction + 0.99 * step_sizes[0], step_sizes[0]) - ).astype(int) - - # Generate indices within each window based on max_width - max_window_len = int(round_mat(n * max_width)) - window_indices = np.arange(max_window_len) - - # Use broadcasting to create the matrix of indices (equivalent to bsxfun(@plus, ...)) - # Indices shape: (num_starts, max_window_len) - all_indices = start_indices[:, None] + window_indices[None, :] - - # Index into sorted data X to get the windows - X_windows = X[all_indices] - - # Get the first element (lower bound) of each window - X1 = X_windows[:, 0].copy() # Use .copy() to avoid potential view issues - - # Subtract the lower bound from each element in its respective window (equivalent to bsxfun(@minus, X, X1)) - X_shifted = X_windows - X1[:, None] # Broadcasting subtraction - - # --- Grid search --- - opt_val = np.inf - opt_beta = np.nan - opt_bounds = np.array([np.nan, np.nan]) - opt_lu = np.array([np.nan, np.nan]) # Lower and Upper data values of the optimal interval - - # Iterate through possible interval widths 'm' exactly as in-house - m_steps = np.int32(round_mat(n * np.arange(max_width, min_width, -step_sizes[1]))) - - for m in m_steps: - if m <= 0: - continue # Skip if width is non-positive - - # --- Scale and bin the data in the intervals --- - nbins = int(round_mat(3 * math.log2(1 + m / 2))) - if nbins <= 0: - continue # Skip if nbins is non-positive - - # scale and bin the data in the intervals exactly as in the MATLAB code - with np.errstate(invalid="ignore", divide="ignore"): - H = np.asarray(X_shifted[:, :m] * nbins / X_shifted[:, m - 1].reshape((-1, 1))) - H[np.isnan(H)] = -1 - bins = list(range(nbins)) - bins.append(np.inf) - logq = np.zeros((H.shape[0], nbins)) - for k in range(H.shape[0]): - h, _ = np.histogram(H[k, :], bins=bins, density=False) - logq[k, :] = np.log(np.asarray(h) + 0.01) - - # --- Inner loop: Iterate through shape parameters (beta) --- - for b in range(len(beta)): - bounds = zbounds[b] - x_vals = bounds[0] + (0.5 + np.arange(nbins)) / nbins * np.diff(bounds) - p = np.exp(-(np.abs(x_vals) ** beta[b])) * rescale[b] - p /= np.sum(p) - kl = np.sum(p * (np.log(p) - logq), axis=1) + np.log(m) - min_val = np.min(kl) - idx = np.where(kl == min_val)[0][0] - if min_val < opt_val: - opt_val = min_val - opt_beta = beta[b] - opt_bounds = bounds - opt_lu = [X1[idx], X1[idx] + X_shifted[idx, m - 1]] - - # --- Recover distribution parameters at optimum --- - if np.any(np.isnan(opt_lu)) or np.any(np.isnan(opt_bounds)): - logger.warning("Optimal parameters not found; returning NaNs.") - return np.nan, np.nan, np.nan, np.nan - - bound_diff = opt_bounds[1] - opt_bounds[0] - if abs(bound_diff) < 1e-9: - logger.warning("Optimal bounds are too close; returning NaNs.") - return np.nan, np.nan, np.nan, np.nan - - # alpha = (opt_lu(2)-opt_lu(1))/diff(opt_bounds); - alpha = (opt_lu[1] - opt_lu[0]) / bound_diff - - # mu = opt_lu(1)-opt_bounds(1)*alpha; - mu = opt_lu[0] - opt_bounds[0] * alpha - - # beta is already opt_beta - final_beta = opt_beta - - # --- Calculate the distribution's standard deviation from alpha and beta --- - # sig = sqrt((alpha^2)*gamma(3/beta)/gamma(1/beta)); - try: - gamma_3_over_beta = gamma(3.0 / final_beta) - gamma_1_over_beta = gamma(1.0 / final_beta) - if gamma_1_over_beta < 1e-9: # Avoid division by near-zero - sig = np.nan - logger.warning("gamma(1/beta) is close to zero; std dev calculation failed.") - else: - sig = np.sqrt((alpha**2) * gamma_3_over_beta / gamma_1_over_beta) - except ValueError: # Catches potential issues with gamma function inputs (e.g., non-positive) - sig = np.nan - logger.warning("Could not calculate std dev due to invalid gamma function input.") - - # Ensure output types are standard Python floats if they are scalar - mu = float(mu) if np.isscalar(mu) else mu - sig = float(sig) if np.isscalar(sig) else sig - alpha = float(alpha) if np.isscalar(alpha) else alpha - final_beta = float(final_beta) if np.isscalar(final_beta) else final_beta - - return mu, sig, alpha, final_beta - - -def geometric_median(X, tol=1.0e-5, y=None, max_iter=500): - """Calculate the geometric median for a set of observations. - - This is the mean under a Laplacian noise distribution, using - Weiszfeld's algorithm. - - Args: - X (np.ndarray): The data, expected shape (n_samples, n_features). - tol (float, optional): Tolerance for convergence. Defaults to 1.e-5. - y (np.ndarray, optional): Initial value for the geometric median. - Defaults to the coordinate-wise median of X. - max_iter (int, optional): Maximum number of iterations. Defaults to 500. - - Returns - ------- - np.ndarray: The geometric median of X, shape (n_features,). - """ - # Ensure X is a numpy array - X = np.asarray(X) - if X.ndim != 2: - raise ValueError("Input data X must be a 2D array (samples x features).") - - # Default initial value: coordinate-wise median - if y is None: - y = np.median(X, axis=0) - else: - y = np.asarray(y) - if y.shape != (X.shape[1],): - raise ValueError(f"Initial guess y must have shape ({X.shape[1]},) matching the number of features in X.") - - # Small constant to prevent division by zero if a point coincides with the median - epsilon = 1e-9 - - for i in range(max_iter): - # Calculate squared distances from each point in X to the current median y - # X shape: (n_samples, n_features), y shape: (n_features,) - # Broadcasting makes (X - y) shape (n_samples, n_features) - squared_distances = np.sum((X - y) ** 2, axis=1) - - # Calculate inverse norms (distances). Add epsilon for numerical stability. - invnorms = 1.0 / np.sqrt(squared_distances + epsilon) - - # Check for exact matches (where distance is near zero) - # If a data point coincides with the current estimate, its weight should be handled carefully. - # Weiszfeld's algorithm can be sensitive here. A common approach is to give these points - # large weight or handle them separately, but simply adding epsilon often suffices. - # Here, the epsilon already prevents division by zero. - - # Update the median estimate - # Weighted average: sum(X * weights) / sum(weights) - # weights are invnorms. Need to broadcast invnorms for element-wise multiplication. - # invnorms shape: (n_samples,), X shape: (n_samples, n_features) - # X * invnorms[:, np.newaxis] has shape (n_samples, n_features) - new_y = np.sum(X * invnorms[:, np.newaxis], axis=0) / np.sum(invnorms) - - # Store the old median and update y - oldy = y - y = new_y - - # Check for convergence: relative change in norm - # Use np.linalg.norm for vector norm - norm_y = np_norm(y) - if norm_y == 0: # Avoid division by zero if the median is the zero vector - if np_norm(y - oldy) < tol: # Check absolute difference if norm is zero - break - elif np_norm(y - oldy) / norm_y < tol: - break - - # Optional: Add a warning if max_iter was reached without convergence - if i == max_iter - 1: - logger.warning(f"Geometric median calculation did not converge within {max_iter} iterations.") - - return y - - -# Helper function ported from asr_calibrate.m -def block_geometric_median(X, blocksize=1, tol=1.0e-5, y=None, max_iter=500): - """Calculate a blockwise geometric median. - - Faster and less memory-intensive than the regular geom_median function. - This statistic is not robust to artifacts that persist over a duration that - is significantly shorter than the blocksize. - - Args: - X (np.ndarray): The data (#observations x #variables). - blocksize (int, optional): The number of successive samples over which a regular mean - should be taken. Defaults to 1. - tol (float, optional): Tolerance for convergence. Defaults to 1.e-5. - y (np.ndarray, optional): Initial value for the geometric median. - Defaults to the coordinate-wise median of X. - max_iter (int, optional): Maximum number of iterations. Defaults to 500. - - Returns - ------- - np.ndarray: Geometric median over X, scaled by 1/blocksize. - - Notes - ----- - This function is noticeably faster if the length of the data is divisible - by the block size. - """ - if blocksize <= 0: - raise ValueError("blocksize must be a positive integer") - if blocksize == 1: - # No blocking needed - return geometric_median(X, tol=tol, y=y, max_iter=max_iter) - - o, v = X.shape # #observations & #variables - if o == 0: - # Handle empty input case - return np.full((v,), np.nan) - - r = o % blocksize # #remainder in last block - b = o // blocksize # #full blocks - - if b > 0: - # Process full blocks - # Reshape to (num_blocks, blocksize, num_variables) and sum along axis 1 - X_blocks = X[: o - r, :].reshape(b, blocksize, v).sum(axis=1) - if r > 0: - # Process remainder block if it exists - X_rem = X[o - r :, :].sum(axis=0, keepdims=True) * (blocksize / r) - # Combine full blocks and scaled remainder - X_processed = np.vstack((X_blocks, X_rem)) - else: - # Only full blocks - X_processed = X_blocks - elif r > 0: - # Only a remainder block exists - X_processed = X[o - r :, :].sum(axis=0, keepdims=True) * (blocksize / r) - else: - # This case should ideally not be reached if o > 0, but handle defensively - return np.full((v,), np.nan) - - # Call the standard geometric median function on the processed data - median_val = geometric_median(X_processed, tol=tol, y=y, max_iter=max_iter) - - # Scale the result by 1/blocksize as per MATLAB implementation - return median_val / blocksize - - -def mad(X, axis=0, keepdims=False): - """Calculate the median absolute deviation from the median along a given axis. - - Args: - X : array-like - Input data array. - axis : int, optional - Axis along which to compute the median absolute deviation. - Default is 0. - keepdims : bool, optional - If True, the result will have the same dimensions as X, - but with the specified axis having size 1. - Default is False. - - Returns - ------- - array-like: - Median absolute deviation of the input data. - """ - med = np.median(X, axis=axis, keepdims=True) - return np.median(np.abs(X - med), axis=axis, keepdims=keepdims) diff --git a/src/eegprep/plugins/clean_rawdata/vis_artifacts.py b/src/eegprep/plugins/clean_rawdata/vis_artifacts.py deleted file mode 100644 index cdfd2205..00000000 --- a/src/eegprep/plugins/clean_rawdata/vis_artifacts.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Diagnostics and browser display for clean_rawdata artifacts.""" - -from __future__ import annotations - -import logging -from typing import Any - -import numpy as np - -from eegprep.functions.popfunc._chanutils import chanlocs_as_list -from eegprep.functions.sigprocfunc.eegplot import eegplot -from eegprep.plugins.clean_rawdata.private.masks import mask_to_intervals - - -logger = logging.getLogger(__name__) - -CLEAN_RAWDATA_REJECT_COLOR = (1.0, 0.95, 0.25) - - -def vis_artifacts(clean_eeg: dict[str, Any], original_eeg: dict[str, Any] | None = None, *, show: bool = True) -> Any: - """Compare a cleaned dataset against its original clean_rawdata source. - - Args: - clean_eeg: Cleaned EEG dictionary, usually returned by ``clean_artifacts``. - original_eeg: Original EEG dictionary. Defaults to ``clean_eeg`` when omitted. - show: Open an ``eegplot`` browser when true; otherwise return diagnostics only. - - Returns: - Diagnostics when ``show`` is false. When ``show`` is true, returns the - ``eegplot`` window/model object if it can be opened, otherwise ``None``. - """ - diagnostics = vis_artifacts_diagnostics(clean_eeg, original_eeg) - if not show: - return diagnostics - - source = diagnostics["original_eeg"] - try: - return eegplot( - source, - winrej=diagnostics["winrej"], - wincolor=CLEAN_RAWDATA_REJECT_COLOR, - events=source.get("event", []), - xgrid="off", - title=f"Rejected data highlighted -- eegplot() -- {source.get('setname', '')}".rstrip(), - ) - except RuntimeError as exc: - logger.info("Could not open clean_rawdata rejected-data browser: %s", exc) - return None - - -def vis_artifacts_diagnostics( - clean_eeg: dict[str, Any], - original_eeg: dict[str, Any] | None = None, -) -> dict[str, Any]: - """Return clean_rawdata sample/channel rejection diagnostics.""" - source = original_eeg if original_eeg is not None else clean_eeg - raw_sample_mask = _mask(clean_eeg, "clean_sample_mask") - raw_channel_mask = _mask(clean_eeg, "clean_channel_mask") - original_samples = _pnts(source) - if original_eeg is None and raw_sample_mask.size: - original_samples = max(original_samples, raw_sample_mask.size) - clean_samples = _pnts(clean_eeg) - original_channels = _nbchan(source) - if original_eeg is None and raw_channel_mask.size: - original_channels = max(original_channels, raw_channel_mask.size) - clean_channels = _nbchan(clean_eeg) - sample_mask = _sample_mask(raw_sample_mask, original_samples) - rejected_intervals = mask_to_intervals(sample_mask, value=False) - rejected_sample_count = int(np.count_nonzero(~sample_mask)) - channel_mask = _channel_mask(raw_channel_mask, original_channels) - removed_indices = [index + 1 for index, keep in enumerate(channel_mask) if not keep] - removed_labels = _removed_channel_labels(source, removed_indices) - return { - "original_eeg": source, - "clean_eeg": clean_eeg, - "original_samples": original_samples, - "clean_samples": clean_samples, - "original_channels": original_channels, - "clean_channels": clean_channels, - "sample_mask": sample_mask, - "rejected_intervals": rejected_intervals, - "rejected_sample_count": rejected_sample_count, - "rejected_fraction": rejected_sample_count / original_samples if original_samples else 0.0, - "channel_mask": channel_mask, - "removed_channel_indices": removed_indices, - "removed_channel_labels": removed_labels, - "removed_channel_count": len(removed_indices), - "winrej": clean_rawdata_winrej(rejected_intervals, original_channels), - } - - -def clean_rawdata_winrej(rejected_intervals: np.ndarray, nbchan: int) -> np.ndarray: - """Convert rejected intervals to an EEGLAB ``eegplot`` ``winrej`` matrix.""" - rejected = np.asarray(rejected_intervals, dtype=float) - if rejected.size == 0: - return np.zeros((0, 5 + int(nbchan)), dtype=float) - rows = np.zeros((rejected.shape[0], 5 + int(nbchan)), dtype=float) - rows[:, 0:2] = rejected[:, 0:2] - rows[:, 2:5] = np.asarray(CLEAN_RAWDATA_REJECT_COLOR, dtype=float) - rows[:, 5:] = 1 - return rows - - -def _mask(clean_eeg: dict[str, Any], key: str) -> np.ndarray: - return np.asarray((clean_eeg.get("etc") or {}).get(key, []), dtype=bool).ravel() - - -def _sample_mask(mask: np.ndarray, original_samples: int) -> np.ndarray: - if mask.size == original_samples: - return mask.copy() - return np.ones(original_samples, dtype=bool) - - -def _channel_mask(mask: np.ndarray, original_channels: int) -> np.ndarray: - if mask.size == original_channels: - return mask.copy() - return np.ones(original_channels, dtype=bool) - - -def _removed_channel_labels(source: dict[str, Any], removed_indices: list[int]) -> list[str]: - chanlocs = _chanlocs(source) - labels = [] - for index in removed_indices: - if 1 <= index <= len(chanlocs): - labels.append(str(chanlocs[index - 1].get("labels", index))) - else: - labels.append(str(index)) - return labels - - -def _pnts(eeg: dict[str, Any]) -> int: - if eeg.get("pnts") is not None: - return int(eeg["pnts"]) - data = np.asarray(eeg.get("data")) - return int(data.shape[1]) if data.ndim >= 2 else 0 - - -def _nbchan(eeg: dict[str, Any]) -> int: - if eeg.get("nbchan") is not None: - return int(eeg["nbchan"]) - data = np.asarray(eeg.get("data")) - return int(data.shape[0]) if data.ndim >= 2 else 0 - - -def _chanlocs(eeg: dict[str, Any]) -> list[dict[str, Any]]: - return [chan if isinstance(chan, dict) else {} for chan in chanlocs_as_list(eeg.get("chanlocs"))] diff --git a/tests/compare_iclabel_engines.py b/tests/compare_iclabel_engines.py deleted file mode 100644 index b24106f0..00000000 --- a/tests/compare_iclabel_engines.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -import sys -import numpy as np -import matplotlib.pyplot as plt -import argparse -from scipy.stats import pearsonr - -# ### Using ICLabel with Different Engines - -# ```python -# from eegprep import pop_loadset, iclabel - -# # Load an EEG file -# EEG = pop_loadset('./sample_data/eeglab_data_with_ica_tmp.set') - -# # Apply ICLabel with the default Python implementation -# EEG_python = iclabel(EEG, algorithm='default', engine=None) -# EEG_matlab = iclabel(EEG, algorithm='default', engine='matlab') -# ``` - -# ### Running the Comparison Script - -# The `test_iclabel_engines.py` script can be used to compare the results of applying ICLabel with different engines: - -# ```bash -# python test_iclabel_engines.py your_eeg_file.set --output_dir results --algorithm default -# ``` - - -# Add the src directory to the path -sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'src')) - -from eegprep import pop_loadset, pop_saveset, iclabel - - -def compare_iclabel_engines(input_file, output_dir=None, engines=None, algorithm='default'): - """ - Compare the results of applying ICLabel with different engines. - - Parameters: - ----------- - input_file : str - Path to the input EEG file - output_dir : str or None - Directory to save output files. If None, files are not saved. - engines : list or None - List of engines to compare. If None, all available engines are used. - Options are: None (Python), 'matlab', 'octave' - algorithm : str - Algorithm to use for classification, passed to the MATLAB/Octave implementation. - - Returns: - -------- - results : dict - Dictionary containing the results of each engine - """ - if engines is None: - engines = [ - None, # Python implementation - 'matlab', - ] - - # Load the input EEG file - print(f"Loading EEG file: {input_file}") - EEG = pop_loadset(input_file) - - # Apply ICLabel with different engines - results = {} - for engine in engines: - try: - engine_name = engine if engine is not None else 'python' - print(f"Applying ICLabel with engine={engine_name}, algorithm={algorithm}") - EEG_result = iclabel(EEG.copy(), algorithm=algorithm, engine=engine) - - # Store the result - results[engine_name] = EEG_result - - # Save the result if output_dir is specified - if output_dir is not None: - output_file = os.path.join(output_dir, f"iclabel_{engine_name}_{algorithm}.set") - pop_saveset(EEG_result, output_file) - print(f"Saved result to: {output_file}") - except Exception as e: - print(f"Error applying ICLabel with engine={engine_name}, algorithm={algorithm}: {e}") - - return results - - -def compare_classifications(results): - """ - Compare the classifications from different engines. - - Parameters: - ----------- - results : dict - Dictionary containing the results of each engine - - Returns: - -------- - comparisons : dict - Dictionary containing the comparison metrics - """ - engines = list(results.keys()) - n_engines = len(engines) - - comparisons = {} - - # Compare each pair of engines - for i in range(n_engines): - for j in range(i + 1, n_engines): - engine1 = engines[i] - engine2 = engines[j] - - # Get the classifications - classifications1 = results[engine1]['etc']['ic_classification']['ICLabel']['classifications'] - classifications2 = results[engine2]['etc']['ic_classification']['ICLabel']['classifications'] - - # Calculate correlation - correlations = [] - for comp_idx in range(classifications1.shape[0]): - corr, _ = pearsonr(classifications1[comp_idx], classifications2[comp_idx]) - correlations.append(corr) - - # Calculate mean absolute difference - mean_abs_diff = np.mean(np.abs(classifications1 - classifications2)) - - # Store the comparison metrics - comparisons[(engine1, engine2)] = { - 'correlations': correlations, - 'mean_correlation': np.mean(correlations), - 'mean_abs_diff': mean_abs_diff, - } - - return comparisons - - -def plot_comparisons(results, comparisons, output_dir=None, algorithm='default'): - """ - Plot the comparisons between different engines. - - Parameters: - ----------- - results : dict - Dictionary containing the results of each engine - comparisons : dict - Dictionary containing the comparison metrics - output_dir : str or None - Directory to save output files. If None, files are not saved. - algorithm : str - Algorithm used for classification - """ - engines = list(results.keys()) - n_engines = len(engines) - - # Plot the classifications for each engine - plt.figure(figsize=(15, 5 * n_engines)) - - for i, engine in enumerate(engines): - plt.subplot(n_engines, 1, i + 1) - classifications = results[engine]['etc']['ic_classification']['ICLabel']['classifications'] - classes = results[engine]['etc']['ic_classification']['ICLabel']['classes'] - - plt.imshow(classifications, aspect='auto', cmap='viridis') - plt.colorbar(label='Probability') - plt.yticks(range(len(classes)), classes) - plt.title(f"ICLabel Classifications - Engine: {engine}, Algorithm: {algorithm}") - plt.xlabel('Component') - - plt.tight_layout() - - if output_dir is not None: - plt.savefig(os.path.join(output_dir, f'iclabel_classifications_{algorithm}.png')) - - # Plot the correlations between engines - plt.figure(figsize=(15, 5 * len(comparisons))) - - for i, ((engine1, engine2), metrics) in enumerate(comparisons.items()): - plt.subplot(len(comparisons), 1, i + 1) - plt.bar(range(len(metrics['correlations'])), metrics['correlations']) - plt.axhline( - metrics['mean_correlation'], color='r', linestyle='--', label=f"Mean: {metrics['mean_correlation']:.3f}" - ) - plt.title(f"Correlation between {engine1} and {engine2}") - plt.xlabel('Component') - plt.ylabel('Correlation') - plt.legend() - - plt.tight_layout() - - if output_dir is not None: - plt.savefig(os.path.join(output_dir, f'iclabel_correlations_{algorithm}.png')) - - -def main(): - """ - Main function to run the comparison. - """ - parser = argparse.ArgumentParser(description='Compare ICLabel engines') - parser.add_argument('input_file', help='Path to the input EEG file') - parser.add_argument('--output_dir', help='Directory to save output files') - parser.add_argument('--algorithm', default='default', help='Algorithm to use for classification') - args = parser.parse_args() - - # Create output directory if it doesn't exist - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - # Compare ICLabel engines - results = compare_iclabel_engines(args.input_file, args.output_dir, algorithm=args.algorithm) - - # Compare classifications - if len(results) > 1: - comparisons = compare_classifications(results) - - # Plot comparisons - plot_comparisons(results, comparisons, args.output_dir, args.algorithm) - - # Wait for user input before closing the plot - plt.show() - - # Print summary - print("\nSummary of comparisons:") - for (engine1, engine2), metrics in comparisons.items(): - print(f"Comparison between {engine1} and {engine2}:") - print(f" Mean correlation: {metrics['mean_correlation']:.3f}") - print(f" Mean absolute difference: {metrics['mean_abs_diff']:.3f}") - else: - print("Not enough results to compare.") - - -if __name__ == '__main__': - main() diff --git a/tests/test_ICL_feature_extractor.py b/tests/test_ICL_feature_extractor.py deleted file mode 100644 index 67ea460a..00000000 --- a/tests/test_ICL_feature_extractor.py +++ /dev/null @@ -1,690 +0,0 @@ -"""Tests for ICL_feature_extractor module. - -This module tests the ICL feature extraction functionality including -topographic plots, power spectral density, autocorrelation features, -and various edge cases. -""" - -import unittest -import numpy as np -import os -import tempfile -import scipy.io - -from eegprep.plugins.ICLabel.ICL_feature_extractor import ICL_feature_extractor -from eegprep.plugins.ICLabel.eeg_rpsd import eeg_rpsd -from eegprep.functions.popfunc.pop_loadset import pop_loadset -from eegprep.functions.popfunc.pop_saveset import pop_saveset -from eegprep.functions.adminfunc.eeglabcompat import get_eeglab - -from tests.fixtures import create_test_eeg - -local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') - - -class TestICLFeatureExtractorBasic(unittest.TestCase): - """Test basic ICL_feature_extractor functionality.""" - - def setUp(self): - """Set up test fixtures with synthetic EEG data.""" - np.random.seed(42) # For reproducible tests - - # Create synthetic EEG data with ICA components - self.n_channels = 32 - self.n_components = 8 - self.n_samples = 1000 - self.srate = 250.0 - - # Create base EEG structure - self.test_eeg = create_test_eeg( - n_channels=self.n_channels, n_samples=self.n_samples, srate=self.srate, n_trials=1 - ) - - # Add required ICA fields - self.test_eeg['icawinv'] = np.random.randn(self.n_channels, self.n_components) * 0.5 - self.test_eeg['icaweights'] = np.linalg.pinv(self.test_eeg['icawinv']) - self.test_eeg['icasphere'] = np.eye(self.n_channels) - self.test_eeg['icaact'] = np.random.randn(self.n_components, self.n_samples, 1) * 0.5 - self.test_eeg['icachansind'] = np.arange(self.n_channels) - self.test_eeg['ref'] = 'averef' - - # Add channel locations (simplified) - create numpy array format - self.test_eeg['chanlocs'] = np.array( - [ - { - 'theta': 45 * i, # Use fixed positions to avoid randomness issues - 'radius': 0.3, - 'X': 0.3 * np.cos(np.radians(45 * i)), - 'Y': 0.3 * np.sin(np.radians(45 * i)), - 'Z': 0.0, - 'labels': f'Ch{i + 1}', - } - for i in range(self.n_channels) - ] - ) - - def test_icl_feature_extractor_missing_ica_winv(self): - """Missing icawinv raises a clear ValueError before any dereference.""" - EEG = self.test_eeg.copy() - del EEG['icawinv'] - - with self.assertRaises(ValueError) as cm: - ICL_feature_extractor(EEG) - self.assertIn('ICA decomposition', str(cm.exception)) - - def test_icl_feature_extractor_empty_ica_winv(self): - """Empty icawinv raises a clear ValueError before any dereference.""" - EEG = self.test_eeg.copy() - EEG['icawinv'] = np.array([]) - - with self.assertRaises(ValueError) as cm: - ICL_feature_extractor(EEG) - self.assertIn('ICA decomposition', str(cm.exception)) - - def test_icl_feature_extractor_missing_ref_field(self): - """A dataset without a 'ref' field is treated as non-average and re-referenced.""" - EEG = self.test_eeg.copy() - del EEG['ref'] - - # Must not raise KeyError on the missing 'ref'; should proceed to feature extraction. - features = ICL_feature_extractor(EEG, flag_autocorr=False) - self.assertEqual(len(features), 2) - - def test_icl_feature_extractor_missing_icaact(self): - """Test ICL_feature_extractor with missing icaact.""" - EEG = self.test_eeg.copy() - EEG['icaact'] = None - - with self.assertRaises(ValueError) as cm: - ICL_feature_extractor(EEG) - self.assertIn('You must have ICA activations', str(cm.exception)) - - def test_icl_feature_extractor_basic_functionality(self): - """Test basic ICL_feature_extractor functionality.""" - features = ICL_feature_extractor(self.test_eeg, flag_autocorr=False) - - # Should return 2 features (topo and psd) when flag_autocorr=False - self.assertEqual(len(features), 2) - - # Check topo features - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, self.n_components)) - self.assertEqual(topo.dtype, np.float32) - self.assertTrue(np.all(np.abs(topo) <= 0.99)) # Should be scaled by 0.99 - - # Check psd features - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, self.n_components)) - self.assertEqual(psd.dtype, np.float32) - self.assertTrue(np.all(np.abs(psd) <= 0.99)) # Should be scaled by 0.99 - - def test_icl_feature_extractor_with_autocorr(self): - """Test ICL_feature_extractor with autocorrelation features.""" - features = ICL_feature_extractor(self.test_eeg, flag_autocorr=True) - - # Should return 3 features (topo, psd, autocorr) when flag_autocorr=True - self.assertEqual(len(features), 3) - - # Check topo features - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, self.n_components)) - self.assertEqual(topo.dtype, np.float32) - - # Check psd features - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, self.n_components)) - self.assertEqual(psd.dtype, np.float32) - - # Check autocorr features - autocorr = features[2] - self.assertEqual(autocorr.ndim, 4) # Should be 4D - self.assertEqual(autocorr.dtype, np.float32) - self.assertEqual(autocorr.shape[3], self.n_components) # Last dimension should be n_components - self.assertTrue(np.all(np.abs(autocorr) <= 0.99)) # Should be scaled by 0.99 - - -class TestICLFeatureExtractorDataTypes(unittest.TestCase): - """Test ICL_feature_extractor with different data types.""" - - def setUp(self): - """Set up test fixtures.""" - np.random.seed(42) - - self.n_channels = 16 # Smaller for faster testing - self.n_components = 4 - self.n_samples = 500 - self.srate = 250.0 - - self.base_eeg = create_test_eeg( - n_channels=self.n_channels, n_samples=self.n_samples, srate=self.srate, n_trials=1 - ) - - self.base_eeg['icawinv'] = np.random.randn(self.n_channels, self.n_components) * 0.5 - self.base_eeg['icaweights'] = np.linalg.pinv(self.base_eeg['icawinv']) - self.base_eeg['icasphere'] = np.eye(self.n_channels) - self.base_eeg['icaact'] = np.random.randn(self.n_components, self.n_samples, 1) * 0.5 - self.base_eeg['icachansind'] = np.arange(self.n_channels) - self.base_eeg['ref'] = 'averef' - - # Add channel locations - self.base_eeg['chanlocs'] = [] - for i in range(self.n_channels): - self.base_eeg['chanlocs'].append( - { - 'theta': np.random.uniform(0, 360), - 'radius': np.random.uniform(0.1, 0.5), - 'X': np.random.uniform(-1, 1), - 'Y': np.random.uniform(-1, 1), - 'Z': np.random.uniform(-1, 1), - 'labels': f'Ch{i + 1}', - } - ) - - def test_icl_feature_extractor_float32_data(self): - """Test ICL_feature_extractor with float32 input data.""" - EEG = self.base_eeg.copy() - EEG['icaact'] = EEG['icaact'].astype(np.float32) - - features = ICL_feature_extractor(EEG, flag_autocorr=False) - - # Should work and return float32 features - self.assertEqual(len(features), 2) - for feature in features: - self.assertEqual(feature.dtype, np.float32) - - def test_icl_feature_extractor_float64_data(self): - """Test ICL_feature_extractor with float64 input data.""" - EEG = self.base_eeg.copy() - EEG['icaact'] = EEG['icaact'].astype(np.float64) - - features = ICL_feature_extractor(EEG, flag_autocorr=False) - - # Should work and return float32 features (converted internally) - self.assertEqual(len(features), 2) - for feature in features: - self.assertEqual(feature.dtype, np.float32) - - -class TestICLFeatureExtractorEdgeCases(unittest.TestCase): - """Test ICL_feature_extractor edge cases.""" - - def setUp(self): - """Set up test fixtures.""" - np.random.seed(42) - - self.n_channels = 8 - self.n_components = 3 - self.n_samples = 250 # 1 second at 250 Hz - self.srate = 250.0 - - self.base_eeg = create_test_eeg( - n_channels=self.n_channels, n_samples=self.n_samples, srate=self.srate, n_trials=1 - ) - - self.base_eeg['icawinv'] = np.random.randn(self.n_channels, self.n_components) * 0.5 - self.base_eeg['icaweights'] = np.linalg.pinv(self.base_eeg['icawinv']) - self.base_eeg['icasphere'] = np.eye(self.n_channels) - self.base_eeg['icaact'] = np.random.randn(self.n_components, self.n_samples, 1) * 0.5 - self.base_eeg['icachansind'] = np.arange(self.n_channels) - self.base_eeg['ref'] = 'averef' - - # Add minimal channel locations - self.base_eeg['chanlocs'] = [] - for i in range(self.n_channels): - self.base_eeg['chanlocs'].append( - { - 'theta': i * 45, # Spread evenly - 'radius': 0.3, - 'X': 0.3 * np.cos(np.radians(i * 45)), - 'Y': 0.3 * np.sin(np.radians(i * 45)), - 'Z': 0.0, - 'labels': f'Ch{i + 1}', - } - ) - - def test_icl_feature_extractor_small_eeg_data(self): - """Test ICL_feature_extractor with small EEG data.""" - features = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - - # Should work with small data - self.assertEqual(len(features), 2) - - # Check feature dimensions - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, self.n_components)) - - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, self.n_components)) - - def test_icl_feature_extractor_single_component(self): - """Test ICL_feature_extractor with single ICA component.""" - EEG = self.base_eeg.copy() - EEG['icawinv'] = EEG['icawinv'][:, :1] # Keep only first component - EEG['icaweights'] = EEG['icaweights'][:1, :] # Keep only first component - EEG['icaact'] = EEG['icaact'][:1, :, :] # Keep only first component - - features = ICL_feature_extractor(EEG, flag_autocorr=False) - - # Should work with single component - self.assertEqual(len(features), 2) - - # Check feature dimensions - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, 1)) # 1 component - - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, 1)) # 1 component - - def test_icl_feature_extractor_many_components(self): - """Test ICL_feature_extractor with many ICA components.""" - EEG = self.base_eeg.copy() - n_many_components = 16 - - # Create many components - EEG['icawinv'] = np.random.randn(self.n_channels, n_many_components) * 0.5 - EEG['icaweights'] = np.linalg.pinv(EEG['icawinv']) - EEG['icaact'] = np.random.randn(n_many_components, self.n_samples, 1) * 0.5 - - features = ICL_feature_extractor(EEG, flag_autocorr=False) - - # Should work with many components - self.assertEqual(len(features), 2) - - # Check feature dimensions - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, n_many_components)) - - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, n_many_components)) - - def test_icl_feature_extractor_very_short_data(self): - """Test ICL_feature_extractor with short data (minimum for 100 freq bins).""" - EEG = self.base_eeg.copy() - - # Use short data - need at least ~200 samples for 100 frequency bins - short_samples = 200 # 0.8 seconds at 250 Hz - EEG['icaact'] = EEG['icaact'][:, :short_samples, :] - EEG['data'] = EEG['data'][:, :short_samples] - EEG['pnts'] = short_samples - EEG['xmax'] = short_samples / self.srate - - features = ICL_feature_extractor(EEG, flag_autocorr=False) - - # Should work with very short data - self.assertEqual(len(features), 2) - - # Features should still have expected shapes - topo = features[0] - self.assertEqual(topo.shape[0:3], (32, 32, 1)) - - psd = features[1] - self.assertEqual(psd.shape[0:3], (1, 100, 1)) - - def test_icl_feature_extractor_autocorr_path_selection(self): - """Test ICL_feature_extractor autocorr path selection based on data length.""" - # Test short data (< 5 seconds) - should use eeg_autocorr - short_eeg = self.base_eeg.copy() - short_pnts = int(3 * self.srate) # 3 seconds = 750 samples - short_eeg['pnts'] = short_pnts - short_eeg['icaact'] = np.random.randn(self.n_components, short_pnts, 1) * 0.5 - short_eeg['data'] = np.random.randn(self.n_channels, short_pnts) * 0.5 - short_eeg['xmax'] = 3.0 - short_eeg['times'] = np.arange(short_pnts) / self.srate - - features = ICL_feature_extractor(short_eeg, flag_autocorr=True) - self.assertEqual(len(features), 3) # Should include autocorr - - # Test long data (> 5 seconds) - should use eeg_autocorr_welch - long_eeg = self.base_eeg.copy() - long_pnts = int(6 * self.srate) # 6 seconds = 1500 samples - long_eeg['pnts'] = long_pnts - long_eeg['icaact'] = np.random.randn(self.n_components, long_pnts, 1) * 0.5 - long_eeg['data'] = np.random.randn(self.n_channels, long_pnts) * 0.5 - long_eeg['xmax'] = 6.0 - long_eeg['times'] = np.arange(long_pnts) / self.srate - - features = ICL_feature_extractor(long_eeg, flag_autocorr=True) - self.assertEqual(len(features), 3) # Should include autocorr - - def test_icl_feature_extractor_multi_trial_data(self): - """Test ICL_feature_extractor with multi-trial data.""" - EEG = self.base_eeg.copy() - - # Convert to multi-trial data - n_trials = 3 - EEG['trials'] = n_trials - EEG['icaact'] = np.random.randn(self.n_components, self.n_samples, n_trials) * 0.5 - EEG['data'] = np.random.randn(self.n_channels, self.n_samples, n_trials) * 0.5 - - features = ICL_feature_extractor(EEG, flag_autocorr=True) - - # Should work with multi-trial data and use eeg_autocorr_fftw - self.assertEqual(len(features), 3) - - # Check that features have correct component dimension - for feature in features: - self.assertEqual(feature.shape[3], self.n_components) - - -class TestICLFeatureExtractorValidation(unittest.TestCase): - """Test ICL_feature_extractor validation and error handling.""" - - def setUp(self): - """Set up test fixtures.""" - np.random.seed(42) - - self.n_channels = 8 - self.n_components = 4 - self.n_samples = 500 - self.srate = 250.0 - - self.base_eeg = create_test_eeg( - n_channels=self.n_channels, n_samples=self.n_samples, srate=self.srate, n_trials=1 - ) - - self.base_eeg['icawinv'] = np.random.randn(self.n_channels, self.n_components) * 0.5 - self.base_eeg['icaweights'] = np.linalg.pinv(self.base_eeg['icawinv']) - self.base_eeg['icasphere'] = np.eye(self.n_channels) - self.base_eeg['icaact'] = np.random.randn(self.n_components, self.n_samples, 1) * 0.5 - self.base_eeg['icachansind'] = np.arange(self.n_channels) - self.base_eeg['ref'] = 'averef' - - # Add channel locations - self.base_eeg['chanlocs'] = [] - for i in range(self.n_channels): - self.base_eeg['chanlocs'].append( - { - 'theta': i * 45, - 'radius': 0.3, - 'X': 0.3 * np.cos(np.radians(i * 45)), - 'Y': 0.3 * np.sin(np.radians(i * 45)), - 'Z': 0.0, - 'labels': f'Ch{i + 1}', - } - ) - - def test_icl_feature_extractor_no_inf_nan_in_features(self): - """Test ICL_feature_extractor produces no inf/nan values in features.""" - features = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - - # Check that no features contain inf or nan - for i, feature in enumerate(features): - self.assertTrue(np.all(np.isfinite(feature)), f"Feature {i} contains inf or nan values") - - def test_icl_feature_extractor_deterministic_seed(self): - """Test ICL_feature_extractor produces consistent results with same seed.""" - # Set seed and extract features - np.random.seed(123) - features1 = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - - # Reset seed and extract features again - np.random.seed(123) - features2 = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - - # Results should be identical (at least for the deterministic parts) - # Note: Some randomness may come from internal functions, so we check structure - self.assertEqual(len(features1), len(features2)) - for i in range(len(features1)): - self.assertEqual(features1[i].shape, features2[i].shape) - self.assertEqual(features1[i].dtype, features2[i].dtype) - - def test_icl_feature_extractor_ref_not_average(self): - """Test ICL_feature_extractor with reference not set to average.""" - EEG = self.base_eeg.copy() - EEG['ref'] = 'Cz' # Not average reference - - # Should still work (function re-references internally) - features = ICL_feature_extractor(EEG, flag_autocorr=False) - self.assertEqual(len(features), 2) - - def test_icl_feature_extractor_mismatched_icachansind(self): - """Test ICL_feature_extractor with mismatched icachansind.""" - EEG = self.base_eeg.copy() - - # Make icachansind mismatch the number of channels - EEG['icachansind'] = np.arange(self.n_channels - 2) # Fewer indices - - try: - # May work or fail depending on implementation - test graceful handling - features = ICL_feature_extractor(EEG, flag_autocorr=False) - if features: - self.assertEqual(len(features), 2) - except (ValueError, IndexError): - # Expected behavior for mismatched indices - pass - - def test_icl_feature_extractor_feature_scaling(self): - """Test ICL_feature_extractor feature scaling (should be scaled by 0.99).""" - features = ICL_feature_extractor(self.base_eeg, flag_autocorr=True) - - # All features should be scaled by 0.99 (max absolute value <= 0.99) - for i, feature in enumerate(features): - max_abs_val = np.max(np.abs(feature)) - self.assertLessEqual(max_abs_val, 0.99 + 1e-6, f"Feature {i} not properly scaled by 0.99") - - def test_icl_feature_extractor_psd_length_extrapolation(self): - """Test ICL_feature_extractor PSD length handling (should be 100 frequencies).""" - features = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - - # PSD should always have 100 frequency bins (extrapolated if needed) - psd = features[1] - self.assertEqual(psd.shape[1], 100, "PSD should have exactly 100 frequency bins") - - -class TestEegRpsdGlobalRng(unittest.TestCase): - """Regression tests that eeg_rpsd never mutates the global numpy RNG.""" - - def setUp(self): - np.random.seed(42) - n_channels = 16 - n_components = 4 - n_samples = 500 - srate = 250.0 - eeg = create_test_eeg(n_channels=n_channels, n_samples=n_samples, srate=srate, n_trials=1) - eeg['icawinv'] = np.random.randn(n_channels, n_components) * 0.5 - eeg['icaweights'] = np.linalg.pinv(eeg['icawinv']) - eeg['icasphere'] = np.eye(n_channels) - eeg['icaact'] = np.random.randn(n_components, n_samples, 1) * 0.5 - eeg['icachansind'] = np.arange(n_channels) - self.eeg = eeg - - def test_does_not_mutate_global_rng(self): - """eeg_rpsd must use a local RNG, leaving np.random's global state intact.""" - np.random.seed(123) - before = np.random.get_state() - eeg_rpsd(self.eeg) - after = np.random.get_state() - - self.assertEqual(before[0], after[0]) - self.assertTrue(np.array_equal(before[1], after[1])) - self.assertEqual(before[2:], after[2:]) - - def test_output_is_deterministic(self): - """eeg_rpsd must return identical output regardless of global RNG state.""" - np.random.seed(1) - psd_a = eeg_rpsd(self.eeg) - np.random.seed(999) - np.random.rand(37) # perturb the global RNG between calls - psd_b = eeg_rpsd(self.eeg) - - self.assertTrue(np.array_equal(psd_a, psd_b)) - - -class TestICLFeatureExtractorParity(unittest.TestCase): - """Test parity between Python and MATLAB ICL_feature_extractor implementations.""" - - def setUp(self): - """Set up test fixtures.""" - # Try to get MATLAB engine - try: - self.eeglab = get_eeglab('MAT', auto_file_roundtrip=False) - self.matlab_available = True - except Exception as e: - self.matlab_available = False - self.skipTest(f"MATLAB not available: {e}") - - # Load real EEG dataset with ICA - test_file = os.path.join(local_url, 'eeglab_data_with_ica_tmp.set') - self.EEG = pop_loadset(test_file) - # Set ref to 'averef' to skip re-referencing (Python's pop_reref differs from MATLAB's) - self.EEG['ref'] = 'averef' - - def test_parity_full_feature_extraction(self): - """Test parity with MATLAB for complete feature extraction.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result - features_py = ICL_feature_extractor(self.EEG.copy(), True) - - # MATLAB result - use file roundtrip for cell array output - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - features = ICL_feature_extractor(EEG, true); - save('{temp_file}.mat', 'features'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB features - mat_data = scipy.io.loadmat(temp_file + '.mat') - features_ml = [mat_data['features'][0, 0], mat_data['features'][0, 1], mat_data['features'][0, 2]] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare all three features - feature_names = ['Topo', 'PSD', 'Autocorr'] - - for i, name in enumerate(feature_names): - py_feat = features_py[i] - ml_feat = features_ml[i] - - # Verify shapes match - self.assertEqual( - py_feat.shape, ml_feat.shape, f"{name} feature shape mismatch: {py_feat.shape} vs {ml_feat.shape}" - ) - - # Compare values - # Max absolute diff: ~6e-8 (float32 precision) - np.testing.assert_allclose( - py_feat, ml_feat, rtol=1e-5, atol=1e-6, err_msg=f"{name} feature differs beyond tolerance" - ) - - def test_parity_topo_feature_only(self): - """Test parity specifically for topography feature.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result - features_py = ICL_feature_extractor(self.EEG.copy(), True) - topo_py = features_py[0] - - # MATLAB result - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - features = ICL_feature_extractor(EEG, true); - topo = features{{1}}; - save('{temp_file}.mat', 'topo'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB result - mat_data = scipy.io.loadmat(temp_file + '.mat') - topo_ml = mat_data['topo'] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare - # Max absolute diff: ~6e-8 (float32 precision) - np.testing.assert_allclose( - topo_py, topo_ml, rtol=1e-5, atol=1e-6, err_msg="Topo feature differs beyond tolerance" - ) - - def test_parity_psd_feature_only(self): - """Test parity specifically for PSD feature.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result - features_py = ICL_feature_extractor(self.EEG.copy(), True) - psd_py = features_py[1] - - # MATLAB result - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - features = ICL_feature_extractor(EEG, true); - psd = features{{2}}; - save('{temp_file}.mat', 'psd'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB result - mat_data = scipy.io.loadmat(temp_file + '.mat') - psd_ml = mat_data['psd'] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare - # Max absolute diff: ~6e-8 (float32 precision) - np.testing.assert_allclose(psd_py, psd_ml, rtol=1e-5, atol=1e-6, err_msg="PSD feature differs beyond tolerance") - - def test_parity_autocorr_feature_only(self): - """Test parity specifically for autocorrelation feature.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result - features_py = ICL_feature_extractor(self.EEG.copy(), True) - autocorr_py = features_py[2] - - # MATLAB result - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - features = ICL_feature_extractor(EEG, true); - autocorr = features{{3}}; - save('{temp_file}.mat', 'autocorr'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB result - mat_data = scipy.io.loadmat(temp_file + '.mat') - autocorr_ml = mat_data['autocorr'] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare - # Max absolute diff: ~6e-8 (float32 precision) - np.testing.assert_allclose( - autocorr_py, autocorr_ml, rtol=1e-5, atol=1e-6, err_msg="Autocorr feature differs beyond tolerance" - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_ICL_feature_extractor_parity.py b/tests/test_ICL_feature_extractor_parity.py deleted file mode 100644 index a5129340..00000000 --- a/tests/test_ICL_feature_extractor_parity.py +++ /dev/null @@ -1,242 +0,0 @@ -""" -Test parity between Python and MATLAB implementations of ICL_feature_extractor. - -This test compares the Python implementation against the MATLAB/EEGLAB reference. -Multithreading is disabled for deterministic numerical results. -""" - -# Disable multithreading for deterministic numerical results -import os - -os.environ["OMP_NUM_THREADS"] = "1" -os.environ["MKL_NUM_THREADS"] = "1" -os.environ["NUMEXPR_NUM_THREADS"] = "1" -os.environ["OPENBLAS_NUM_THREADS"] = "1" -os.environ["VECLIB_MAXIMUM_THREADS"] = "1" - -import unittest -import numpy as np -import tempfile -import scipy.io -from eegprep import pop_loadset, pop_saveset, ICL_feature_extractor -from eegprep.functions.adminfunc.eeglabcompat import get_eeglab - -local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') - -ICLABEL_PARITY_RTOL = 2e-5 -ICLABEL_PARITY_ATOL = 1e-8 -# MATLAB ICLabel casts PSD features to single precision before returning them. -ICLABEL_PSD_PARITY_ATOL = 5e-8 - - -class TestICLFeatureExtractorParity(unittest.TestCase): - """Test parity between Python and MATLAB ICL_feature_extractor implementations.""" - - def setUp(self): - """Set up test fixtures.""" - # Try to get MATLAB engine - try: - self.eeglab = get_eeglab('MAT', auto_file_roundtrip=False) - self.matlab_available = True - except Exception as e: - self.matlab_available = False - self.skipTest(f"MATLAB not available: {e}") - - # Load real EEG dataset with ICA - test_file = os.path.join(local_url, 'eeglab_data_with_ica_tmp.set') - self.EEG = pop_loadset(test_file) - - def test_parity_without_autocorr(self): - """Test parity with MATLAB without autocorrelation (flag_autocorr=False).""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result - py_features = ICL_feature_extractor(self.EEG.copy(), flag_autocorr=False) - - # MATLAB result - use file roundtrip - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - features = ICL_feature_extractor(EEG, false); - topo = features{{1}}; - psd = features{{2}}; - save('{temp_file}.mat', 'topo', 'psd'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB result - mat_data = scipy.io.loadmat(temp_file + '.mat') - ml_topo = mat_data['topo'] - ml_psd = mat_data['psd'] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare topo feature - py_topo = py_features[0] - print("\nTopo comparison:") - print(f" Python shape: {py_topo.shape}") - print(f" MATLAB shape: {ml_topo.shape}") - print(f" Max absolute diff: {np.max(np.abs(py_topo - ml_topo)):.6f}") - - # Calculate mismatched elements - mismatch_mask = ~np.isclose(py_topo, ml_topo, rtol=ICLABEL_PARITY_RTOL, atol=ICLABEL_PARITY_ATOL) - n_mismatch = np.sum(mismatch_mask) - n_total = py_topo.size - print(f" Mismatched elements: {n_mismatch} / {n_total} ({100 * n_mismatch / n_total:.2f}%)") - - if n_mismatch > 0: - max_rel_diff = np.max(np.abs((py_topo - ml_topo) / (ml_topo + 1e-10))) - print(f" Max relative diff: {max_rel_diff:.6f}") - - np.testing.assert_allclose( - py_topo, - ml_topo, - rtol=ICLABEL_PARITY_RTOL, - atol=ICLABEL_PARITY_ATOL, - err_msg="Topo feature differs beyond tolerance", - ) - - # Compare psd feature - py_psd = py_features[1] - print("\nPSD comparison:") - print(f" Python shape: {py_psd.shape}") - print(f" MATLAB shape: {ml_psd.shape}") - print(f" Max absolute diff: {np.max(np.abs(py_psd - ml_psd)):.6f}") - - # Calculate mismatched elements - mismatch_mask = ~np.isclose(py_psd, ml_psd, rtol=ICLABEL_PARITY_RTOL, atol=ICLABEL_PSD_PARITY_ATOL) - n_mismatch = np.sum(mismatch_mask) - n_total = py_psd.size - print(f" Mismatched elements: {n_mismatch} / {n_total} ({100 * n_mismatch / n_total:.2f}%)") - - if n_mismatch > 0: - max_rel_diff = np.max(np.abs((py_psd - ml_psd) / (ml_psd + 1e-10))) - print(f" Max relative diff: {max_rel_diff:.6f}") - - np.testing.assert_allclose( - py_psd, - ml_psd, - rtol=ICLABEL_PARITY_RTOL, - atol=ICLABEL_PSD_PARITY_ATOL, - err_msg="PSD feature differs beyond tolerance", - ) - - def test_parity_with_autocorr(self): - """Test parity with MATLAB with autocorrelation (flag_autocorr=True).""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result - py_features = ICL_feature_extractor(self.EEG.copy(), flag_autocorr=True) - - # MATLAB result - use file roundtrip - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - features = ICL_feature_extractor(EEG, true); - topo = features{{1}}; - psd = features{{2}}; - autocorr = features{{3}}; - save('{temp_file}.mat', 'topo', 'psd', 'autocorr'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB result - mat_data = scipy.io.loadmat(temp_file + '.mat') - ml_topo = mat_data['topo'] - ml_psd = mat_data['psd'] - ml_autocorr = mat_data['autocorr'] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare topo feature - py_topo = py_features[0] - print("\nTopo comparison:") - print(f" Python shape: {py_topo.shape}") - print(f" MATLAB shape: {ml_topo.shape}") - print(f" Max absolute diff: {np.max(np.abs(py_topo - ml_topo)):.6f}") - - # Calculate mismatched elements - mismatch_mask = ~np.isclose(py_topo, ml_topo, rtol=ICLABEL_PARITY_RTOL, atol=ICLABEL_PARITY_ATOL) - n_mismatch = np.sum(mismatch_mask) - n_total = py_topo.size - print(f" Mismatched elements: {n_mismatch} / {n_total} ({100 * n_mismatch / n_total:.2f}%)") - - if n_mismatch > 0: - max_rel_diff = np.max(np.abs((py_topo - ml_topo) / (ml_topo + 1e-10))) - print(f" Max relative diff: {max_rel_diff:.6f}") - - np.testing.assert_allclose( - py_topo, - ml_topo, - rtol=ICLABEL_PARITY_RTOL, - atol=ICLABEL_PARITY_ATOL, - err_msg="Topo feature differs beyond tolerance", - ) - - # Compare psd feature - py_psd = py_features[1] - print("\nPSD comparison:") - print(f" Python shape: {py_psd.shape}") - print(f" MATLAB shape: {ml_psd.shape}") - print(f" Max absolute diff: {np.max(np.abs(py_psd - ml_psd)):.6f}") - - # Calculate mismatched elements - mismatch_mask = ~np.isclose(py_psd, ml_psd, rtol=ICLABEL_PARITY_RTOL, atol=ICLABEL_PSD_PARITY_ATOL) - n_mismatch = np.sum(mismatch_mask) - n_total = py_psd.size - print(f" Mismatched elements: {n_mismatch} / {n_total} ({100 * n_mismatch / n_total:.2f}%)") - - if n_mismatch > 0: - max_rel_diff = np.max(np.abs((py_psd - ml_psd) / (ml_psd + 1e-10))) - print(f" Max relative diff: {max_rel_diff:.6f}") - - np.testing.assert_allclose( - py_psd, - ml_psd, - rtol=ICLABEL_PARITY_RTOL, - atol=ICLABEL_PSD_PARITY_ATOL, - err_msg="PSD feature differs beyond tolerance", - ) - - # Compare autocorr feature - py_autocorr = py_features[2] - print("\nAutocorr comparison:") - print(f" Python shape: {py_autocorr.shape}") - print(f" MATLAB shape: {ml_autocorr.shape}") - print(f" Max absolute diff: {np.max(np.abs(py_autocorr - ml_autocorr)):.6f}") - - # Calculate mismatched elements - mismatch_mask = ~np.isclose(py_autocorr, ml_autocorr, rtol=ICLABEL_PARITY_RTOL, atol=ICLABEL_PARITY_ATOL) - n_mismatch = np.sum(mismatch_mask) - n_total = py_autocorr.size - print(f" Mismatched elements: {n_mismatch} / {n_total} ({100 * n_mismatch / n_total:.2f}%)") - - if n_mismatch > 0: - max_rel_diff = np.max(np.abs((py_autocorr - ml_autocorr) / (ml_autocorr + 1e-10))) - print(f" Max relative diff: {max_rel_diff:.6f}") - - np.testing.assert_allclose( - py_autocorr, - ml_autocorr, - rtol=ICLABEL_PARITY_RTOL, - atol=ICLABEL_PARITY_ATOL, - err_msg="Autocorr feature differs beyond tolerance", - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_clean_artifacts.py b/tests/test_clean_artifacts.py deleted file mode 100644 index 89d725fc..00000000 --- a/tests/test_clean_artifacts.py +++ /dev/null @@ -1,993 +0,0 @@ -""" -Test suite for clean_artifacts.py - All-in-one artifact removal. - -This module tests the clean_artifacts function that provides comprehensive -artifact removal including flatline channels, drifts, noisy channels, bursts, and windows. -""" - -import unittest -import sys -import numpy as np - -# Add src to path for imports -sys.path.insert(0, 'src') -from eegprep.plugins.clean_rawdata.clean_artifacts import clean_artifacts -from eegprep.plugins.clean_rawdata.pop_clean_rawdata import pop_clean_rawdata -from eegprep.utils.testing import DebuggableTestCase - -from tests.fixtures import create_test_eeg as _create_test_eeg - - -def create_test_eeg(): - """Continuous (2D) EEG fixture sized for clean_artifacts (20 s at 500 Hz).""" - return _create_test_eeg(n_channels=32, n_samples=10000, srate=500.0, n_trials=1) - - -class TestCleanArtifactsBasic(DebuggableTestCase): - """Basic test cases for clean_artifacts function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_basic_functionality(self): - """Test basic clean_artifacts functionality with default parameters.""" - EEG, HP, BUR, removed_channels = clean_artifacts(self.test_eeg) - - # Check that all return values are present - self.assertIsInstance(EEG, dict) - self.assertIsInstance(HP, dict) - self.assertIsInstance(BUR, dict) - self.assertIsInstance(removed_channels, np.ndarray) - - # Check that EEG structure is preserved - self.assertIn('data', EEG) - self.assertIn('srate', EEG) - self.assertIn('nbchan', EEG) - self.assertIn('pnts', EEG) - - # Check that data dimensions are reasonable - self.assertEqual(EEG['srate'], self.test_eeg['srate']) - self.assertGreaterEqual(EEG['nbchan'], 1) # At least one channel should remain - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) - - def test_clean_artifacts_all_off(self): - """Test clean_artifacts with all criteria disabled.""" - self.test_eeg.pop('etc') - original_keys = set(self.test_eeg) - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # With all criteria off, data should be unchanged - self.assertEqual(EEG['nbchan'], self.test_eeg['nbchan']) - self.assertEqual(EEG['pnts'], self.test_eeg['pnts']) - np.testing.assert_array_equal(EEG['data'], self.test_eeg['data']) - self.assertEqual(set(self.test_eeg), original_keys) - - def test_clean_artifacts_invalid_highpass_string(self): - """Test clean_artifacts with invalid highpass string parameter.""" - with self.assertRaises(ValueError) as cm: - clean_artifacts(self.test_eeg, Highpass='invalid') - self.assertIn('Highpass must be a (low, high) tuple or None/"off"', str(cm.exception)) - - def test_clean_artifacts_invalid_highpass_single_value(self): - """Test clean_artifacts with single value instead of tuple.""" - with self.assertRaises(ValueError) as cm: - clean_artifacts(self.test_eeg, Highpass=0.5) - self.assertIn('Highpass must be a (low, high) tuple or None/"off"', str(cm.exception)) - - def test_clean_artifacts_invalid_highpass_too_many_values(self): - """Test clean_artifacts with too many values in highpass tuple.""" - with self.assertRaises(ValueError) as cm: - clean_artifacts(self.test_eeg, Highpass=(0.1, 0.5, 1.0)) - self.assertIn('Highpass must be a (low, high) tuple or None/"off"', str(cm.exception)) - - def test_clean_artifacts_invalid_highpass_empty_tuple(self): - """Test clean_artifacts with empty highpass tuple.""" - with self.assertRaises(ValueError) as cm: - clean_artifacts(self.test_eeg, Highpass=()) - self.assertIn('Highpass must be a (low, high) tuple or None/"off"', str(cm.exception)) - - def test_clean_artifacts_invalid_highpass_list_single(self): - """Test clean_artifacts with single-element list.""" - with self.assertRaises(ValueError) as cm: - clean_artifacts(self.test_eeg, Highpass=[0.5]) - self.assertIn('Highpass must be a (low, high) tuple or None/"off"', str(cm.exception)) - - def test_clean_artifacts_valid_highpass_list(self): - """Test clean_artifacts with valid highpass list (should work like tuple).""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Highpass=[0.25, 0.75], # List instead of tuple - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - FlatlineCriterion='off', - ) - # Should work - list is acceptable - self.assertIsInstance(EEG, dict) - - def test_clean_artifacts_mutually_exclusive_channels(self): - """Test clean_artifacts with mutually exclusive channel parameters.""" - with self.assertRaises(ValueError) as cm: - clean_artifacts(self.test_eeg, Channels=['Ch1', 'Ch2'], Channels_ignore=['Ch3']) - self.assertIn('mutually exclusive', str(cm.exception)) - - def test_clean_artifacts_mutually_exclusive_channels_both_empty(self): - """Test clean_artifacts with both channel parameters empty (should work).""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels=[], # Empty list - Channels_ignore=[], # Empty list - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - # Should work - empty lists are not mutually exclusive - self.assertIsInstance(EEG, dict) - - def test_clean_artifacts_mutually_exclusive_channels_none_and_list(self): - """Test clean_artifacts with None and non-empty list (should work).""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels=None, # None - Channels_ignore=['Ch1'], # Non-empty list - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - # Should work - None and list is not mutually exclusive - self.assertIsInstance(EEG, dict) - - def test_clean_artifacts_mutually_exclusive_channels_both_none(self): - """Test clean_artifacts with both channel parameters as None (should work).""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels=None, # None - Channels_ignore=None, # None - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - # Should work - both None is not mutually exclusive - self.assertIsInstance(EEG, dict) - - def test_clean_artifacts_mutually_exclusive_channels_overlapping(self): - """Test clean_artifacts with overlapping channel lists (error expected).""" - with self.assertRaises(ValueError) as cm: - clean_artifacts( - self.test_eeg, - Channels=['Ch1', 'Ch2', 'Ch3'], - Channels_ignore=['Ch2', 'Ch4'], # Ch2 overlaps - ) - self.assertIn('mutually exclusive', str(cm.exception)) - - -class TestCleanArtifactsFlatline(DebuggableTestCase): - """Test cases for flatline channel removal.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_flatline_removal(self): - """Test flatline channel removal.""" - # Create some flatline channels - eeg_with_flatlines = self.test_eeg.copy() - eeg_with_flatlines['data'] = self.test_eeg['data'].copy() - eeg_with_flatlines['data'][5, :] = 0.0 # Flatline channel (2D data) - eeg_with_flatlines['data'][10, :] = 1.0 # Another flatline channel - original_nbchan = eeg_with_flatlines['nbchan'] - - EEG, HP, BUR, removed_channels = clean_artifacts( - eeg_with_flatlines, - FlatlineCriterion=1.0, # Short flatline duration - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - ) - - # Should have removed some channels - self.assertLess(EEG['nbchan'], original_nbchan) - - def test_clean_artifacts_flatline_off(self): - """Test flatline removal disabled.""" - # Create some flatline channels - eeg_with_flatlines = self.test_eeg.copy() - eeg_with_flatlines['data'] = self.test_eeg['data'].copy() - eeg_with_flatlines['data'][5, :] = 0.0 # Flatline channel (2D data) - - EEG, HP, BUR, removed_channels = clean_artifacts( - eeg_with_flatlines, - FlatlineCriterion='off', - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - ) - - # Should not have removed any channels - self.assertEqual(EEG['nbchan'], eeg_with_flatlines['nbchan']) - - -class TestCleanArtifactsHighpass(DebuggableTestCase): - """Test cases for highpass filtering (drift removal).""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_highpass_filtering(self): - """Test highpass filtering.""" - original_data = self.test_eeg['data'].copy() # Save before call - - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Highpass=(0.5, 1.0), - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - FlatlineCriterion='off', - ) - - # HP should contain the highpass filtered data - self.assertIsInstance(HP, dict) - self.assertIn('data', HP) - - # Data should be different after filtering - self.assertFalse(np.array_equal(HP['data'], original_data)) - - def test_clean_artifacts_highpass_off(self): - """Test highpass filtering disabled.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Highpass='off', - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - FlatlineCriterion='off', - ) - - # Data should be unchanged - np.testing.assert_array_equal(HP['data'], self.test_eeg['data']) - - -class TestCleanArtifactsChannelCleaning(DebuggableTestCase): - """Test cases for channel cleaning.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_channel_criterion(self): - """Test channel correlation criterion.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion=0.9, # High threshold - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have removed some channels with high threshold - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) - - def test_clean_artifacts_line_noise_criterion(self): - """Test line noise criterion.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion=2.0, # Low threshold - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have removed some channels with low threshold - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) - - def test_clean_artifacts_both_channel_criteria(self): - """Test both channel and line noise criteria.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion=0.8, - LineNoiseCriterion=4.0, - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have removed some channels - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) - - -class TestCleanArtifactsBurstCleaning(DebuggableTestCase): - """Test cases for burst cleaning (ASR).""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_burst_criterion(self): - """Test burst criterion.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion=5.0, - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # BUR should contain the burst repaired data - self.assertIsInstance(BUR, dict) - self.assertIn('data', BUR) - - def test_clean_artifacts_burst_rejection(self): - """Test burst rejection mode.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion=5.0, - BurstRejection='on', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have removed some samples - self.assertLessEqual(EEG['pnts'], self.test_eeg['pnts']) - - def test_clean_artifacts_burst_off(self): - """Test burst cleaning disabled.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Data should be unchanged - np.testing.assert_array_equal(BUR['data'], self.test_eeg['data']) - - -class TestCleanArtifactsWindowCleaning(DebuggableTestCase): - """Test cases for window cleaning.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_window_criterion(self): - """Test window criterion.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion=0.5, # Allow 50% bad channels per window - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have removed some samples - self.assertLessEqual(EEG['pnts'], self.test_eeg['pnts']) - - def test_clean_artifacts_window_off(self): - """Test window cleaning disabled.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Data should be unchanged - self.assertEqual(EEG['pnts'], self.test_eeg['pnts']) - - -class TestCleanArtifactsChannelSelection(DebuggableTestCase): - """Test cases for channel selection.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_channels_include(self): - """Test channel inclusion.""" - channels_to_include = ['Ch1', 'Ch2', 'Ch3'] - - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels=channels_to_include, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have only the specified channels - self.assertEqual(EEG['nbchan'], len(channels_to_include)) - - def test_clean_artifacts_channels_ignore(self): - """Test channel exclusion.""" - channels_to_ignore = ['Ch1', 'Ch2'] - original_nbchan = self.test_eeg['nbchan'] # Save before call - - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels_ignore=channels_to_ignore, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have fewer channels - self.assertEqual(EEG['nbchan'], original_nbchan - len(channels_to_ignore)) - - -class TestCleanArtifactsParameterValidation(DebuggableTestCase): - """Test cases for parameter validation and edge cases.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_invalid_channel_criterion_type(self): - """Test clean_artifacts with invalid ChannelCriterion type.""" - # Should accept numeric values and 'off' - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion=0.8, - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - def test_clean_artifacts_invalid_line_noise_criterion_type(self): - """Test clean_artifacts with invalid LineNoiseCriterion type.""" - # Should accept numeric values and 'off' - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion=4.0, - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - def test_clean_artifacts_invalid_burst_criterion_type(self): - """Test clean_artifacts with invalid BurstCriterion type.""" - # Should accept numeric values and 'off' - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion=5.0, - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - def test_clean_artifacts_invalid_window_criterion_type(self): - """Test clean_artifacts with invalid WindowCriterion type.""" - # Should accept numeric values and 'off' - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion=0.25, - Highpass='off', - FlatlineCriterion='off', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - def test_clean_artifacts_invalid_flatline_criterion_type(self): - """Test clean_artifacts with invalid FlatlineCriterion type.""" - # Should accept numeric values and 'off' - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion=5.0, - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - def test_clean_artifacts_invalid_burst_rejection_type(self): - """Test clean_artifacts with invalid BurstRejection type.""" - # Should accept 'on' and 'off' strings - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - BurstRejection='on', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - BurstRejection='off', - ) - - def test_clean_artifacts_documented_distance_metrics_with_asr_disabled(self): - """Test clean_artifacts accepts documented Distance spellings when ASR is disabled.""" - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - Distance='euclidian', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - Distance='riemannian', - ) - - def test_clean_artifacts_rejects_unknown_distance_metric(self): - """Test clean_artifacts rejects unknown Distance spellings before cleaning.""" - with self.assertRaisesRegex(ValueError, "Distance must be"): - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - Distance='riemann', - ) - - def test_clean_artifacts_negative_values(self): - """Test clean_artifacts with negative parameter values.""" - # Some parameters should handle negative values gracefully - try: - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - MaxMem=-1, - ) # Negative MaxMem should be handled - except Exception: - # Negative values may cause errors - this is acceptable - pass - - def test_clean_artifacts_zero_values(self): - """Test clean_artifacts with zero parameter values.""" - clean_artifacts( - self.test_eeg, - ChannelCriterion=0.0, # Zero correlation threshold - LineNoiseCriterion=0.0, - BurstCriterion='off', - WindowCriterion=0.0, - Highpass='off', - FlatlineCriterion=0.0, - ) - - def test_clean_artifacts_extreme_values(self): - """Test clean_artifacts with extreme parameter values.""" - clean_artifacts( - self.test_eeg, - ChannelCriterion=1.0, # Perfect correlation required - LineNoiseCriterion=100.0, - BurstCriterion='off', - WindowCriterion=1.0, - Highpass='off', - FlatlineCriterion=1000.0, - ) - - -class TestCleanArtifactsParameters(DebuggableTestCase): - """Test cases for various parameters.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_available_ram(self): - """Test available RAM parameter.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - availableRAM_GB=2.0, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should complete without error - self.assertIsInstance(EEG, dict) - - def test_clean_artifacts_distance_metric(self): - """Test distance metric parameter.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Distance='euclidian', - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should complete without error - self.assertIsInstance(EEG, dict) - - def test_clean_artifacts_max_mem(self): - """Test max memory parameter.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - MaxMem=128, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should complete without error - self.assertIsInstance(EEG, dict) - - -class TestCleanArtifactsIntegration(DebuggableTestCase): - """Integration test cases for clean_artifacts.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_artifacts_full_pipeline(self): - """Test the full clean_artifacts pipeline.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - FlatlineCriterion=5.0, - Highpass=(0.25, 0.75), - ChannelCriterion=0.8, - LineNoiseCriterion=4.0, - BurstCriterion=5.0, - WindowCriterion=0.25, - ) - - # Check all return values - self.assertIsInstance(EEG, dict) - self.assertIsInstance(HP, dict) - self.assertIsInstance(BUR, dict) - self.assertIsInstance(removed_channels, np.ndarray) - - # Check data integrity - self.assertIn('data', EEG) - self.assertIn('srate', EEG) - self.assertIn('nbchan', EEG) - self.assertIn('pnts', EEG) - - # Check that some processing occurred - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) - - def test_clean_artifacts_return_values(self): - """Test that all return values have correct structure.""" - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Check EEG structure - self.assertIn('data', EEG) - self.assertIn('srate', EEG) - self.assertIn('nbchan', EEG) - self.assertIn('pnts', EEG) - self.assertIn('etc', EEG) - - # Check HP structure (should be same as EEG when no highpass) - self.assertIn('data', HP) - self.assertIn('srate', HP) - self.assertIn('nbchan', HP) - self.assertIn('pnts', HP) - - # Check BUR structure (should be same as EEG when no burst cleaning) - self.assertIn('data', BUR) - self.assertIn('srate', BUR) - self.assertIn('nbchan', BUR) - self.assertIn('pnts', BUR) - - # Check removed_channels array - self.assertEqual(len(removed_channels), self.test_eeg['nbchan']) - self.assertTrue(np.issubdtype(removed_channels.dtype, np.bool_)) - - -class TestCleanArtifactsHpSnapshot(DebuggableTestCase): - """Regression test for the high-pass snapshot point-in-time contract.""" - - def setUp(self): - np.random.seed(11) - self.test_eeg = create_test_eeg() - - def test_hp_snapshot_is_point_in_time(self): - """HP must not carry the sample mask written by the later window stage.""" - EEG, HP, _BUR, _removed = clean_artifacts( - self.test_eeg, - Highpass='off', - ChannelCriterion=0.8, - LineNoiseCriterion=4.0, - BurstCriterion='off', - WindowCriterion=0.25, - ) - - # The window stage populates clean_sample_mask on the final EEG dataset... - self.assertIn('clean_sample_mask', EEG['etc']) - # ...but the high-pass snapshot predates that stage, so it must not share - # the same etc object or carry the later mask. - self.assertIsNot(HP['etc'], EEG['etc']) - self.assertNotIn('clean_sample_mask', HP['etc']) - - -class TestPopCleanRawdataNoMutation(DebuggableTestCase): - """Regression test that pop_clean_rawdata never mutates the caller's EEG.""" - - def setUp(self): - np.random.seed(13) - self.test_eeg = create_test_eeg() - - def test_does_not_mutate_input(self): - """The wrapper must deep-copy so the caller's dataset is untouched.""" - EEG_in = self.test_eeg - original_data = EEG_in['data'].copy() - original_nbchan = EEG_in['nbchan'] - - cleaned = pop_clean_rawdata( - EEG_in, - gui=False, - ChannelCriterion=0.8, - LineNoiseCriterion=4.0, - BurstCriterion='off', - WindowCriterion=0.25, - ) - - # Caller's data, channel count, and etc are all unchanged. - self.assertTrue(np.array_equal(original_data, EEG_in['data'])) - self.assertEqual(EEG_in['nbchan'], original_nbchan) - self.assertNotIn('clean_channel_mask', EEG_in.get('etc', {})) - self.assertNotIn('clean_sample_mask', EEG_in.get('etc', {})) - # The returned dataset is a distinct object. - self.assertIsNot(cleaned, EEG_in) - - -class TestCleanArtifactsErrorSurfacing(DebuggableTestCase): - """Errors inside the selection / channel-cleaning paths must surface, not be masked.""" - - def setUp(self): - self.test_eeg = create_test_eeg() - - def test_pop_select_internal_error_propagates(self): - """A non-ImportError raised inside pop_select must propagate, not silently - fall back to manual label selection (which could select different channels). - """ - import eegprep - - original = eegprep.pop_select - - def failing_pop_select(*args, **kwargs): - raise RuntimeError("simulated pop_select bug") - - eegprep.pop_select = failing_pop_select - try: - with self.assertRaises(RuntimeError): - clean_artifacts( - self.test_eeg, - Channels_ignore=['EEG001', 'EEG002'], - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - finally: - eegprep.pop_select = original - - def test_channels_ignore_preserves_events(self): - """Restricting channels must not wipe the dataset's events.""" - eeg = create_test_eeg() - eeg['event'] = [{'type': 'mark', 'latency': 100.0}, {'type': 'mark', 'latency': 5000.0}] - original_events = list(eeg['event']) - - EEG, _HP, _BUR, _removed = clean_artifacts( - eeg, - Channels_ignore=['EEG001'], - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - self.assertEqual(len(EEG['event']), len(original_events)) - - def test_clean_channels_unexpected_value_error_propagates(self): - """A ValueError from clean_channels unrelated to missing locations must - propagate rather than silently switching to the no-locs algorithm. - """ - import eegprep.plugins.clean_rawdata.clean_artifacts as ca_mod - - original = ca_mod.clean_channels - - def boom(*args, **kwargs): - raise ValueError("totally unrelated bug") - - ca_mod.clean_channels = boom - try: - with self.assertRaises(ValueError) as cm: - clean_artifacts( - self.test_eeg, - ChannelCriterion=0.8, - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - self.assertIn('totally unrelated bug', str(cm.exception)) - finally: - ca_mod.clean_channels = original - - def test_clean_channels_location_error_falls_back(self): - """A missing-locations ValueError still triggers the no-locs fallback.""" - import eegprep.plugins.clean_rawdata.clean_artifacts as ca_mod - - original_cc = ca_mod.clean_channels - original_nolocs = ca_mod.clean_channels_nolocs - - def locs_error(*args, **kwargs): - raise ValueError('To use this function most of your channels should have X,Y,Z location measurements.') - - called = {'nolocs': False} - - def fake_nolocs(EEG, **kwargs): - called['nolocs'] = True - return EEG, np.zeros(EEG['nbchan'], dtype=bool) - - ca_mod.clean_channels = locs_error - ca_mod.clean_channels_nolocs = fake_nolocs - try: - clean_artifacts( - self.test_eeg, - ChannelCriterion=0.8, - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - self.assertTrue(called['nolocs']) - finally: - ca_mod.clean_channels = original_cc - ca_mod.clean_channels_nolocs = original_nolocs - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_clean_asr.py b/tests/test_clean_asr.py deleted file mode 100644 index 829bc390..00000000 --- a/tests/test_clean_asr.py +++ /dev/null @@ -1,509 +0,0 @@ -"""Tests for clean_asr module. - -This module tests the Artifact Subspace Reconstruction (ASR) functionality -including parameter validation, calibration data selection, and various -processing switches. -""" - -import unittest -import numpy as np - -from eegprep.plugins.clean_rawdata.clean_asr import clean_asr -from eegprep.plugins.clean_rawdata.clean_artifacts import clean_artifacts - - -class TestCleanASRBasic(unittest.TestCase): - """Test basic clean_asr functionality.""" - - def setUp(self): - """Set up test fixtures with synthetic EEG data.""" - np.random.seed(42) # For reproducible tests - - # Create synthetic EEG data - self.n_channels = 8 - self.n_samples = 1000 # 4 seconds at 250 Hz - self.srate = 250.0 - - self.test_eeg = { - 'data': np.random.randn(self.n_channels, self.n_samples) * 0.5, - 'srate': self.srate, - 'nbchan': self.n_channels, - 'pnts': self.n_samples, - 'trials': 1, - 'xmin': 0.0, - 'xmax': (self.n_samples - 1) / self.srate, - } - - def test_clean_asr_missing_required_fields(self): - """Test clean_asr with missing required EEG fields.""" - # Missing 'data' - incomplete_eeg = {'srate': 250.0, 'nbchan': 8} - with self.assertRaises(ValueError) as cm: - clean_asr(incomplete_eeg) - self.assertIn("EEG dictionary must contain", str(cm.exception)) - - # Missing 'srate' - incomplete_eeg = {'data': np.random.randn(8, 1000), 'nbchan': 8} - with self.assertRaises(ValueError) as cm: - clean_asr(incomplete_eeg) - self.assertIn("EEG dictionary must contain", str(cm.exception)) - - # Missing 'nbchan' - incomplete_eeg = {'data': np.random.randn(8, 1000), 'srate': 250.0} - with self.assertRaises(ValueError) as cm: - clean_asr(incomplete_eeg) - self.assertIn("EEG dictionary must contain", str(cm.exception)) - - def test_clean_asr_basic_functionality(self): - """Test basic clean_asr functionality without mocking (may skip if ASR fails).""" - try: - result = clean_asr(self.test_eeg, cutoff=20.0) # Very conservative cutoff - - # Verify result structure - self.assertIsInstance(result, dict) - self.assertIn('data', result) - self.assertEqual(result['data'].shape, self.test_eeg['data'].shape) - self.assertEqual(result['srate'], self.test_eeg['srate']) - self.assertEqual(result['nbchan'], self.test_eeg['nbchan']) - - except Exception as e: - self.skipTest(f"clean_asr basic functionality not available: {e}") - - def test_clean_asr_nbchan_mismatch_warning(self): - """Test clean_asr with mismatch between nbchan and data shape.""" - mismatched_eeg = self.test_eeg.copy() - mismatched_eeg['nbchan'] = 10 # Doesn't match data shape (8, 1000) - - try: - with self.assertLogs('eegprep.plugins.clean_rawdata.clean_asr', level='WARNING') as log: - result = clean_asr(mismatched_eeg, cutoff=20.0) - - self.assertTrue(any('Mismatch between' in msg for msg in log.output)) - self.assertIsInstance(result, dict) - except Exception as e: - self.skipTest(f"clean_asr nbchan mismatch test not available: {e}") - - -class TestCleanASRParameters(unittest.TestCase): - """Test clean_asr parameter handling and validation.""" - - def setUp(self): - """Set up test fixtures.""" - np.random.seed(42) - self.n_channels = 4 # Smaller for faster testing - self.n_samples = 500 # Shorter for faster testing - self.srate = 250.0 - - self.test_eeg = { - 'data': np.random.randn(self.n_channels, self.n_samples) * 0.5, - 'srate': self.srate, - 'nbchan': self.n_channels, - 'pnts': self.n_samples, - 'trials': 1, - 'xmin': 0.0, - 'xmax': (self.n_samples - 1) / self.srate, - 'times': np.arange(self.n_samples) / self.srate, - 'event': [], - } - - def test_clean_asr_parameter_acceptance(self): - """Test that clean_asr accepts various parameter combinations.""" - test_cases = [ - {'cutoff': 3.0}, - {'cutoff': 20.0}, # Very conservative - {'window_len': 1.0}, - {'step_size': 64}, - {'max_dims': 0.5}, - {'use_gpu': True}, # Should be ignored in current implementation - {'maxmem': 128}, - {'useriemannian': 'calib'}, - ] - - for params in test_cases: - try: - result = clean_asr(self.test_eeg, **params) - self.assertIsInstance(result, dict) - self.assertIn('data', result) - except Exception as e: - self.skipTest(f"clean_asr parameter test {params} not available: {e}") - - def test_clean_asr_rejects_full_riemannian_processing_request(self): - """Test that unsupported full Riemannian ASR processing fails clearly.""" - with self.assertRaisesRegex(ValueError, "full Riemannian ASR processing is not ported"): - clean_asr(self.test_eeg, useriemannian=True) - - def test_clean_artifacts_rejects_unknown_distance(self): - """Test that clean_artifacts catches Distance typos before ASR dispatch.""" - with self.assertRaisesRegex(ValueError, "Distance must be"): - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion=20, - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - Distance='euclidiann', - ) - - -class TestCleanASRCalibrationData(unittest.TestCase): - """Test clean_asr calibration data selection options.""" - - def setUp(self): - """Set up test fixtures.""" - np.random.seed(42) - self.n_channels = 4 # Smaller for faster testing - self.n_samples = 500 # Shorter for faster testing - self.srate = 250.0 - - self.test_eeg = { - 'data': np.random.randn(self.n_channels, self.n_samples) * 0.5, - 'srate': self.srate, - 'nbchan': self.n_channels, - 'pnts': self.n_samples, - 'trials': 1, - 'xmin': 0.0, - 'xmax': (self.n_samples - 1) / self.srate, - 'times': np.arange(self.n_samples) / self.srate, - 'event': [], - } - - def test_clean_asr_ref_maxbadchannels_off(self): - """Test clean_asr with ref_maxbadchannels='off'.""" - try: - with self.assertLogs('eegprep.plugins.clean_rawdata.clean_asr', level='INFO') as log: - result = clean_asr(self.test_eeg, ref_maxbadchannels='off', cutoff=20.0) - - self.assertTrue(any('Using the entire data for calibration' in msg for msg in log.output)) - self.assertIsInstance(result, dict) - except Exception as e: - self.skipTest(f"clean_asr ref_maxbadchannels='off' test not available: {e}") - - def test_clean_asr_ref_tolerances_off(self): - """Test clean_asr with ref_tolerances='off'.""" - try: - with self.assertLogs('eegprep.plugins.clean_rawdata.clean_asr', level='INFO') as log: - result = clean_asr(self.test_eeg, ref_tolerances='off', cutoff=20.0) - - self.assertTrue(any('Using the entire data for calibration' in msg for msg in log.output)) - self.assertIsInstance(result, dict) - except Exception as e: - self.skipTest(f"clean_asr ref_tolerances='off' test not available: {e}") - - def test_clean_asr_ref_wndlen_off(self): - """Test clean_asr with ref_wndlen='off'.""" - try: - with self.assertLogs('eegprep.plugins.clean_rawdata.clean_asr', level='INFO') as log: - result = clean_asr(self.test_eeg, ref_wndlen='off', cutoff=20.0) - - self.assertTrue(any('Using the entire data for calibration' in msg for msg in log.output)) - self.assertIsInstance(result, dict) - except Exception as e: - self.skipTest(f"clean_asr ref_wndlen='off' test not available: {e}") - - def test_clean_asr_user_supplied_calibration_data(self): - """Test clean_asr with user-supplied calibration data.""" - # Create user-supplied calibration data - user_calib_data = np.random.randn(self.n_channels, 250) * 0.3 - - try: - with self.assertLogs('eegprep.plugins.clean_rawdata.clean_asr', level='INFO') as log: - result = clean_asr(self.test_eeg, ref_maxbadchannels=user_calib_data, cutoff=20.0) - - self.assertTrue(any('Using user-supplied data array' in msg for msg in log.output)) - self.assertIsInstance(result, dict) - except Exception as e: - self.skipTest(f"clean_asr user-supplied calibration data test not available: {e}") - - def test_clean_asr_invalid_user_calibration_data_shape(self): - """Test clean_asr with invalid user-supplied calibration data shape.""" - # Wrong shape (1D instead of 2D) - invalid_calib_data = np.random.randn(100) - - with self.assertRaises(ValueError) as cm: - clean_asr(self.test_eeg, ref_maxbadchannels=invalid_calib_data) - self.assertIn('must be a 2D array', str(cm.exception)) - - # Wrong number of channels - invalid_calib_data = np.random.randn(5, 500) # 5 channels instead of 8 - - with self.assertRaises(ValueError) as cm: - clean_asr(self.test_eeg, ref_maxbadchannels=invalid_calib_data) - self.assertIn('must be a 2D array with shape', str(cm.exception)) - - def test_clean_asr_invalid_ref_maxbadchannels_type(self): - """Test clean_asr with invalid ref_maxbadchannels type.""" - with self.assertRaises(ValueError) as cm: - clean_asr(self.test_eeg, ref_maxbadchannels='invalid_string') - self.assertIn('Unsupported value or type for', str(cm.exception)) - - with self.assertRaises(ValueError) as cm: - clean_asr(self.test_eeg, ref_maxbadchannels=['invalid', 'list']) - self.assertIn('Unsupported value or type for', str(cm.exception)) - - -class TestCleanASRCalibrationFailure(unittest.TestCase): - """Test clean_asr behavior when calibration fails.""" - - def setUp(self): - """Set up test fixtures.""" - np.random.seed(42) - self.n_channels = 4 - self.n_samples = 100 # Very short data to potentially trigger failures - self.srate = 250.0 - - self.test_eeg = { - 'data': np.random.randn(self.n_channels, self.n_samples) * 0.5, - 'srate': self.srate, - 'nbchan': self.n_channels, - 'pnts': self.n_samples, - 'trials': 1, - } - - def test_clean_asr_insufficient_calibration_data(self): - """Test clean_asr when there's insufficient calibration data.""" - # Use very short data that might cause calibration failure - short_eeg = self.test_eeg.copy() - short_eeg['data'] = np.random.randn(self.n_channels, 10) * 0.5 # Only 10 samples - short_eeg['pnts'] = 10 - - with self.assertRaises(ValueError) as cm: - clean_asr(short_eeg, cutoff=5.0) - # Should contain "ASR calibration failed" in the error message - self.assertIn('ASR calibration failed', str(cm.exception)) - - def test_clean_asr_automatic_calibration_fallback(self): - """Test clean_asr fallback when automatic calibration data selection is used.""" - try: - # Use parameters that trigger automatic calibration data selection - with self.assertLogs('eegprep.plugins.clean_rawdata.clean_asr', level='INFO') as log: - result = clean_asr( - self.test_eeg, - ref_maxbadchannels=0.1, - ref_tolerances=(-3.0, 5.0), - ref_wndlen=1.0, - cutoff=20.0, # Conservative - ) - - # Should attempt to find clean calibration data - self.assertTrue(any('Finding a clean section' in msg for msg in log.output)) - self.assertIsInstance(result, dict) - except Exception as e: - # If clean_windows or calibration fails, we should see appropriate error handling - if 'ASR calibration failed' in str(e): - pass # Expected behavior for insufficient data - else: - self.skipTest(f"clean_asr automatic calibration test not available: {e}") - - def test_clean_asr_unexpected_clean_windows_error_propagates(self): - """An unexpected (non-ValueError) failure in clean_windows must propagate, - not be swallowed into a silent 'use all data for calibration' fallback. - """ - import eegprep.plugins.clean_rawdata.clean_asr as casr_mod - - original = casr_mod.clean_windows - - def boom(*args, **kwargs): - raise RuntimeError("simulated clean_windows bug") - - casr_mod.clean_windows = boom - try: - with self.assertRaises(RuntimeError) as cm: - clean_asr( - self.test_eeg, - ref_maxbadchannels=0.1, - ref_tolerances=(-3.0, 5.0), - ref_wndlen=1.0, - cutoff=20.0, - ) - self.assertIn('simulated clean_windows bug', str(cm.exception)) - finally: - casr_mod.clean_windows = original - - def test_clean_asr_clean_windows_value_error_falls_back(self): - """A ValueError from clean_windows (expected calibration-data problem) still - triggers the documented all-data fallback rather than crashing. - """ - import eegprep.plugins.clean_rawdata.clean_asr as casr_mod - - original = casr_mod.clean_windows - - def insufficient(*args, **kwargs): - raise ValueError('Not enough data for even a single window.') - - # Enough samples that all-data calibration succeeds once the fallback kicks in. - eeg = { - 'data': np.random.randn(self.n_channels, 5000) * 0.5, - 'srate': self.srate, - 'nbchan': self.n_channels, - 'pnts': 5000, - 'trials': 1, - } - - casr_mod.clean_windows = insufficient - try: - with self.assertLogs('eegprep.plugins.clean_rawdata.clean_asr', level='WARNING') as log: - result = clean_asr( - eeg, - ref_maxbadchannels=0.1, - ref_tolerances=(-3.0, 5.0), - ref_wndlen=1.0, - cutoff=20.0, - ) - self.assertTrue(any('Falling back to using the entire data' in msg for msg in log.output)) - self.assertIsInstance(result, dict) - finally: - casr_mod.clean_windows = original - - -class TestCleanASRSignalExtrapolation(unittest.TestCase): - """Test clean_asr signal extrapolation logic.""" - - def setUp(self): - """Set up test fixtures.""" - np.random.seed(42) - self.n_channels = 4 - self.n_samples = 500 - self.srate = 250.0 - - self.test_eeg = { - 'data': np.random.randn(self.n_channels, self.n_samples) * 0.5, - 'srate': self.srate, - 'nbchan': self.n_channels, - 'pnts': self.n_samples, - 'trials': 1, - 'xmin': 0.0, - 'xmax': (self.n_samples - 1) / self.srate, - 'times': np.arange(self.n_samples) / self.srate, - 'event': [], - } - - def test_clean_asr_with_different_window_lengths(self): - """Test clean_asr with different window lengths that affect extrapolation.""" - window_lengths = [0.2, 0.5, 1.0] - - for window_len in window_lengths: - try: - result = clean_asr(self.test_eeg, window_len=window_len, cutoff=20.0) - - # Should preserve original data shape - self.assertEqual(result['data'].shape, self.test_eeg['data'].shape) - self.assertIsInstance(result, dict) - - except Exception as e: - self.skipTest(f"clean_asr window_len={window_len} test not available: {e}") - - def test_clean_asr_very_short_data(self): - """Test clean_asr with very short data that affects extrapolation.""" - # Create very short data - short_eeg = self.test_eeg.copy() - short_eeg['data'] = np.random.randn(self.n_channels, 50) * 0.5 # Only 50 samples - short_eeg['pnts'] = 50 - - try: - result = clean_asr(short_eeg, window_len=0.5, cutoff=20.0) - self.assertEqual(result['data'].shape[1], 50) # Should preserve original length - self.assertIsInstance(result, dict) - except ValueError as e: - # Expected for insufficient data - if 'ASR calibration failed' in str(e): - pass # This is expected behavior - else: - raise - except Exception as e: - self.skipTest(f"clean_asr very short data test not available: {e}") - - -class TestCleanASREdgeCases(unittest.TestCase): - """Test clean_asr edge cases and error handling.""" - - def setUp(self): - """Set up test fixtures.""" - np.random.seed(42) - self.n_channels = 4 - self.n_samples = 500 - self.srate = 250.0 - - self.test_eeg = { - 'data': np.random.randn(self.n_channels, self.n_samples) * 0.5, - 'srate': self.srate, - 'nbchan': self.n_channels, - 'pnts': self.n_samples, - 'trials': 1, - 'xmin': 0.0, - 'xmax': (self.n_samples - 1) / self.srate, - 'times': np.arange(self.n_samples) / self.srate, - 'event': [], - } - - def test_clean_asr_single_channel_data(self): - """Test clean_asr with single channel data.""" - single_channel_eeg = { - 'data': np.random.randn(1, self.n_samples) * 0.5, - 'srate': self.srate, - 'nbchan': 1, - 'pnts': self.n_samples, - 'trials': 1, - 'xmin': 0.0, - 'xmax': (self.n_samples - 1) / self.srate, - 'times': np.arange(self.n_samples) / self.srate, - 'event': [], - } - - try: - result = clean_asr(single_channel_eeg, cutoff=20.0) - self.assertIsInstance(result, dict) - self.assertEqual(result['data'].shape[0], 1) - except Exception as e: - self.skipTest(f"clean_asr single channel test not available: {e}") - - def test_clean_asr_different_data_types(self): - """Test clean_asr with different input data types.""" - # Test with float32 data - float32_eeg = self.test_eeg.copy() - float32_eeg['data'] = float32_eeg['data'].astype(np.float32) - - try: - result = clean_asr(float32_eeg, cutoff=20.0) - self.assertIsInstance(result, dict) - # Data should be processed regardless of input dtype - self.assertTrue(np.isfinite(result['data']).all()) - except Exception as e: - self.skipTest(f"clean_asr different data types test not available: {e}") - - def test_clean_asr_step_size_computation(self): - """Test clean_asr step_size computation when None.""" - try: - # Test that function completes with step_size=None (should compute internally) - result = clean_asr(self.test_eeg, window_len=1.0, step_size=None, cutoff=20.0) - self.assertIsInstance(result, dict) - self.assertEqual(result['data'].shape, self.test_eeg['data'].shape) - except Exception as e: - self.skipTest(f"clean_asr step_size computation test not available: {e}") - - def test_clean_asr_window_len_computation(self): - """Test clean_asr window_len computation when None.""" - try: - # Test that function completes with window_len=None (should compute internally) - result = clean_asr(self.test_eeg, window_len=None, cutoff=20.0) - self.assertIsInstance(result, dict) - self.assertEqual(result['data'].shape, self.test_eeg['data'].shape) - except Exception as e: - self.skipTest(f"clean_asr window_len computation test not available: {e}") - - def test_clean_asr_extreme_cutoff_values(self): - """Test clean_asr with extreme cutoff values.""" - extreme_cutoffs = [1.0, 50.0] # Very aggressive and very conservative - - for cutoff in extreme_cutoffs: - try: - result = clean_asr(self.test_eeg, cutoff=cutoff) - self.assertIsInstance(result, dict) - self.assertEqual(result['data'].shape, self.test_eeg['data'].shape) - except Exception as e: - self.skipTest(f"clean_asr extreme cutoff={cutoff} test not available: {e}") - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_clean_drifts.py b/tests/test_clean_drifts.py deleted file mode 100644 index 5f00607e..00000000 --- a/tests/test_clean_drifts.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -Test suite for clean_drifts.py - Drift removal filtering. - -This module tests the clean_drifts function that removes low-frequency -drifts from EEG data using high-pass filtering. -""" - -import unittest -import sys -import numpy as np - -# Add src to path for imports -sys.path.insert(0, 'src') -from eegprep.plugins.clean_rawdata.clean_drifts import clean_drifts -from eegprep.utils.testing import DebuggableTestCase - -from tests.fixtures import create_test_eeg as _create_test_eeg - - -def create_test_eeg(): - """Continuous (2D) EEG fixture sized for clean_drifts (20 s at 500 Hz).""" - return _create_test_eeg(n_channels=32, n_samples=10000, srate=500.0, n_trials=1) - - -class TestCleanDriftsBasic(DebuggableTestCase): - """Basic test cases for clean_drifts function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_drifts_basic_functionality(self): - """Test basic clean_drifts functionality with default parameters.""" - result = clean_drifts(self.test_eeg.copy()) - - # Check that EEG structure is preserved - self.assertIn('data', result) - self.assertIn('srate', result) - self.assertIn('nbchan', result) - self.assertIn('pnts', result) - self.assertIn('etc', result) - - # Check that data dimensions are preserved - self.assertEqual(result['srate'], self.test_eeg['srate']) - self.assertEqual(result['nbchan'], self.test_eeg['nbchan']) - self.assertEqual(result['pnts'], self.test_eeg['pnts']) - self.assertEqual(result['trials'], self.test_eeg['trials']) - - # Check that data type is float64 - self.assertEqual(result['data'].dtype, np.float64) - - # Check that filter kernel is stored - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_default_parameters(self): - """Test clean_drifts with default parameters.""" - result = clean_drifts(self.test_eeg.copy()) - - # Should work with default parameters - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_custom_transition(self): - """Test clean_drifts with custom transition band.""" - result = clean_drifts(self.test_eeg.copy(), transition=(1.0, 2.0)) - - # Should work with custom transition band - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_custom_attenuation(self): - """Test clean_drifts with custom attenuation.""" - result = clean_drifts(self.test_eeg.copy(), attenuation=60.0) - - # Should work with custom attenuation - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_fir_method(self): - """Test clean_drifts with FIR method.""" - result = clean_drifts(self.test_eeg.copy(), method='fir') - - # Should work with FIR method - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_fft_method(self): - """Test clean_drifts with FFT method.""" - result = clean_drifts(self.test_eeg.copy(), method='fft') - - # Should work with FFT method - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - -class TestCleanDriftsEdgeCases(DebuggableTestCase): - """Edge case test cases for clean_drifts function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_drifts_single_channel(self): - """Test clean_drifts with single channel data.""" - # Create single channel data (2D continuous) - single_channel_eeg = self.test_eeg.copy() - single_channel_eeg['data'] = np.random.randn(1, 10000) - single_channel_eeg['nbchan'] = 1 - single_channel_eeg['chanlocs'] = [single_channel_eeg['chanlocs'][0]] - - result = clean_drifts(single_channel_eeg) - - # Should work with single channel - self.assertEqual(result['nbchan'], 1) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_single_trial(self): - """Test clean_drifts with continuous (single trial) data.""" - # Create continuous data (2D - single trial is the normal case) - single_trial_eeg = self.test_eeg.copy() - single_trial_eeg['data'] = np.random.randn(32, 10000) - single_trial_eeg['trials'] = 1 - - result = clean_drifts(single_trial_eeg) - - # Should work with single trial - self.assertEqual(result['trials'], 1) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_continuous_data(self): - """Test clean_drifts with continuous (2D) data.""" - # Create continuous data (2D) - continuous_eeg = self.test_eeg.copy() - continuous_eeg['data'] = np.random.randn(32, 1000) - continuous_eeg['trials'] = 1 - - result = clean_drifts(continuous_eeg) - - # Should work with continuous data - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_float32_data(self): - """Test clean_drifts with float32 data.""" - # Create float32 data - float32_eeg = self.test_eeg.copy() - float32_eeg['data'] = np.random.randn(32, 10000).astype(np.float32) - - result = clean_drifts(float32_eeg) - - # Should convert to float64 - self.assertEqual(result['data'].dtype, np.float64) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_float64_data(self): - """Test clean_drifts with float64 data.""" - # Create float64 data - float64_eeg = self.test_eeg.copy() - float64_eeg['data'] = np.random.randn(32, 10000).astype(np.float64) - - result = clean_drifts(float64_eeg) - - # Should remain float64 - self.assertEqual(result['data'].dtype, np.float64) - self.assertIn('clean_drifts_kernel', result['etc']) - - -class TestCleanDriftsParameters(DebuggableTestCase): - """Parameter test cases for clean_drifts function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_drifts_different_transition_bands(self): - """Test clean_drifts with different transition bands.""" - # Test different transition bands - transitions = [(0.1, 0.5), (0.5, 1.0), (1.0, 2.0), (2.0, 5.0)] - - for transition in transitions: - result = clean_drifts(self.test_eeg.copy(), transition=transition) - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_different_attenuations(self): - """Test clean_drifts with different attenuation values.""" - # Test different attenuation values - attenuations = [40.0, 60.0, 80.0, 100.0] - - for attenuation in attenuations: - result = clean_drifts(self.test_eeg.copy(), attenuation=attenuation) - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - def test_clean_drifts_both_methods(self): - """Test clean_drifts with both FIR and FFT methods.""" - # Test both methods - methods = ['fir', 'fft'] - - for method in methods: - result = clean_drifts(self.test_eeg.copy(), method=method) - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - -class TestCleanDriftsIntegration(DebuggableTestCase): - """Integration test cases for clean_drifts function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_drifts_preserves_structure(self): - """Test that clean_drifts preserves EEG structure.""" - result = clean_drifts(self.test_eeg.copy()) - - # Check that all essential fields are preserved - essential_fields = ['data', 'srate', 'nbchan', 'pnts', 'trials', 'xmin', 'xmax', 'times', 'chanlocs'] - for field in essential_fields: - self.assertIn(field, result) - - # Check that data integrity is maintained - self.assertEqual(result['srate'], self.test_eeg['srate']) - self.assertEqual(result['nbchan'], self.test_eeg['nbchan']) - self.assertEqual(result['pnts'], self.test_eeg['pnts']) - self.assertEqual(result['trials'], self.test_eeg['trials']) - - def test_clean_drifts_data_modification(self): - """Test that clean_drifts actually modifies the data.""" - original_data = self.test_eeg['data'].copy() - result = clean_drifts(self.test_eeg.copy()) - - # Data should be modified (filtered) - self.assertFalse(np.array_equal(original_data, result['data'])) - - # But shape should be preserved - self.assertEqual(original_data.shape, result['data'].shape) - - def test_clean_drifts_kernel_properties(self): - """Test properties of the filter kernel.""" - result = clean_drifts(self.test_eeg.copy()) - - kernel = result['etc']['clean_drifts_kernel'] - - # Kernel should be a numpy array - self.assertIsInstance(kernel, np.ndarray) - - # Kernel should not be empty - self.assertGreater(len(kernel), 0) - - # Kernel should be 1D - self.assertEqual(kernel.ndim, 1) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_clean_flatlines.py b/tests/test_clean_flatlines.py deleted file mode 100644 index 40723893..00000000 --- a/tests/test_clean_flatlines.py +++ /dev/null @@ -1,496 +0,0 @@ -""" -Test suite for clean_flatlines.py - Flatline channel removal. - -This module tests the clean_flatlines function that removes channels with -prolonged flatline periods from EEG data. -""" - -import unittest -import sys -import numpy as np - -# Add src to path for imports -sys.path.insert(0, 'src') -from eegprep.plugins.clean_rawdata.clean_flatlines import clean_flatlines -from eegprep.utils.testing import DebuggableTestCase - -from tests.fixtures import create_test_eeg as _create_test_eeg - - -def create_test_eeg(): - """Epoched EEG fixture sized for clean_flatlines (32 ch, 1000 pnts, 10 trials).""" - return _create_test_eeg(n_channels=32, n_samples=1000, srate=500.0, n_trials=10) - - -class TestCleanFlatlinesBasic(DebuggableTestCase): - """Basic test cases for clean_flatlines function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_flatlines_basic_functionality(self): - """Test basic clean_flatlines functionality with default parameters.""" - result = clean_flatlines(self.test_eeg.copy()) - - # Check that EEG structure is preserved - self.assertIn('data', result) - self.assertIn('srate', result) - self.assertIn('nbchan', result) - self.assertIn('pnts', result) - - # Check that data dimensions are reasonable - self.assertEqual(result['srate'], self.test_eeg['srate']) - self.assertLessEqual(result['nbchan'], self.test_eeg['nbchan']) - self.assertGreaterEqual(result['nbchan'], 1) # At least one channel should remain - - def test_clean_flatlines_no_flatlines(self): - """Test clean_flatlines with data that has no flatlines.""" - # Create data with no flatlines - eeg_no_flatlines = self.test_eeg.copy() - eeg_no_flatlines['data'] = np.random.randn(32, 1000, 10) - - result = clean_flatlines(eeg_no_flatlines) - - # Should not remove any channels - self.assertEqual(result['nbchan'], eeg_no_flatlines['nbchan']) - - def test_clean_flatlines_with_flatlines(self): - """Test clean_flatlines with data that has flatlines.""" - # Create data with some flatlines - use constant values to create proper flatlines - eeg_with_flatlines = self.test_eeg.copy() - # Create flatlines by setting consecutive samples to the same value - eeg_with_flatlines['data'][5, :, :] = 1.0 # Constant value channel - eeg_with_flatlines['data'][10, :, :] = 0.0 # Another constant value channel - - result = clean_flatlines(eeg_with_flatlines, max_flatline_duration=1.0) - - # Note: Current implementation may not detect flatlines as expected - # Test that the function completes without error - self.assertIsInstance(result, dict) - - def test_clean_flatlines_all_flatlines(self): - """Test clean_flatlines when all channels have flatlines.""" - # Create data where all channels have flatlines - eeg_all_flatlines = self.test_eeg.copy() - eeg_all_flatlines['data'] = np.zeros_like(eeg_all_flatlines['data']) - - result = clean_flatlines(eeg_all_flatlines, max_flatline_duration=1.0) - - # Should not remove all channels (warning should be logged) - self.assertEqual(result['nbchan'], eeg_all_flatlines['nbchan']) - - def test_clean_flatlines_custom_duration(self): - """Test clean_flatlines with custom flatline duration.""" - # Create data with short flatlines - eeg_short_flatlines = self.test_eeg.copy() - # Create a short flatline by setting a portion to constant value - eeg_short_flatlines['data'][5, :500, :] = 1.0 # Short flatline - - # Test with short duration - result1 = clean_flatlines(eeg_short_flatlines, max_flatline_duration=0.5) - self.assertIsInstance(result1, dict) - - # Test with long duration - result2 = clean_flatlines(eeg_short_flatlines, max_flatline_duration=5.0) - self.assertIsInstance(result2, dict) - - def test_clean_flatlines_custom_jitter(self): - """Test clean_flatlines with custom jitter tolerance.""" - # Create data with slight variations (jitter) - eeg_with_jitter = self.test_eeg.copy() - # Add very small jitter to a constant channel - base_value = 1.0 - jitter = 1e-10 * np.random.randn(1000, 10) - eeg_with_jitter['data'][5, :, :] = base_value + jitter - - # Test with low jitter tolerance - result1 = clean_flatlines(eeg_with_jitter, max_allowed_jitter=1.0) - self.assertIsInstance(result1, dict) - - # Test with high jitter tolerance - result2 = clean_flatlines(eeg_with_jitter, max_allowed_jitter=100.0) - self.assertIsInstance(result2, dict) - - -class TestCleanFlatlinesEdgeCases(DebuggableTestCase): - """Edge case test cases for clean_flatlines function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_flatlines_single_channel(self): - """Test clean_flatlines with single channel data.""" - # Create single channel data - single_channel_eeg = self.test_eeg.copy() - single_channel_eeg['data'] = np.random.randn(1, 1000, 10) - single_channel_eeg['nbchan'] = 1 - single_channel_eeg['chanlocs'] = [single_channel_eeg['chanlocs'][0]] - - result = clean_flatlines(single_channel_eeg) - - # Should preserve the channel - self.assertEqual(result['nbchan'], 1) - - def test_clean_flatlines_single_trial(self): - """Test clean_flatlines with single trial data.""" - # Create single trial data - single_trial_eeg = self.test_eeg.copy() - single_trial_eeg['data'] = np.random.randn(32, 1000, 1) - single_trial_eeg['trials'] = 1 - - result = clean_flatlines(single_trial_eeg) - - # Should preserve structure - self.assertEqual(result['trials'], 1) - self.assertEqual(result['data'].shape[2], 1) - - def test_clean_flatlines_continuous_data(self): - """Test clean_flatlines with continuous data (no trials dimension).""" - # Create continuous data - continuous_eeg = self.test_eeg.copy() - continuous_eeg['data'] = np.random.randn(32, 1000) - continuous_eeg['trials'] = 1 - - result = clean_flatlines(continuous_eeg) - - # Should preserve structure - self.assertEqual(result['trials'], 1) - self.assertEqual(len(result['data'].shape), 2) - - def test_clean_flatlines_with_clean_channel_mask(self): - """Test clean_flatlines with existing clean_channel_mask.""" - eeg_with_mask = self.test_eeg.copy() - eeg_with_mask['etc'] = {'clean_channel_mask': np.ones(32, dtype=bool)} - - # Create a flatline - eeg_with_mask['data'][5, :, :] = 1.0 - - result = clean_flatlines(eeg_with_mask, max_flatline_duration=1.0) - - # Should update the mask if channel is removed - self.assertIn('clean_channel_mask', result['etc']) - if result['nbchan'] < eeg_with_mask['nbchan']: - self.assertFalse(result['etc']['clean_channel_mask'][5]) - - def test_clean_flatlines_without_clean_channel_mask(self): - """Test clean_flatlines without existing clean_channel_mask.""" - eeg_no_mask = self.test_eeg.copy() - eeg_no_mask['etc'] = {} - - # Create a flatline - eeg_no_mask['data'][5, :, :] = 1.0 - - result = clean_flatlines(eeg_no_mask, max_flatline_duration=1.0) - - # Should create a new mask if channel is removed - if result['nbchan'] < eeg_no_mask['nbchan']: - self.assertIn('clean_channel_mask', result['etc']) - - def test_clean_flatlines_with_ica_fields(self): - """Test clean_flatlines with ICA fields present.""" - eeg_with_ica = self.test_eeg.copy() - eeg_with_ica['icawinv'] = np.random.randn(32, 10) - eeg_with_ica['icasphere'] = np.random.randn(32, 32) - eeg_with_ica['icaweights'] = np.random.randn(10, 32) - eeg_with_ica['icaact'] = np.random.randn(10, 1000, 10) - - # Create a flatline - eeg_with_ica['data'][5, :, :] = 1.0 - - result = clean_flatlines(eeg_with_ica, max_flatline_duration=1.0) - - # ICA fields should be cleared when channels are removed - if result['nbchan'] < eeg_with_ica['nbchan']: - self.assertEqual(len(result['icawinv']), 0) - self.assertEqual(len(result['icasphere']), 0) - self.assertEqual(len(result['icaweights']), 0) - self.assertEqual(len(result['icaact']), 0) - - -class TestCleanFlatlinesDataTypes(DebuggableTestCase): - """Data type test cases for clean_flatlines function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_flatlines_float32_data(self): - """Test clean_flatlines with float32 data.""" - eeg_float32 = self.test_eeg.copy() - eeg_float32['data'] = np.random.randn(32, 1000, 10).astype(np.float32) - - result = clean_flatlines(eeg_float32) - - # Should preserve data type - self.assertEqual(result['data'].dtype, np.float32) - - def test_clean_flatlines_float64_data(self): - """Test clean_flatlines with float64 data.""" - eeg_float64 = self.test_eeg.copy() - eeg_float64['data'] = np.random.randn(32, 1000, 10).astype(np.float64) - - result = clean_flatlines(eeg_float64) - - # Should convert to float32 when channels are removed - if result['nbchan'] < eeg_float64['nbchan']: - self.assertEqual(result['data'].dtype, np.float32) - else: - self.assertEqual(result['data'].dtype, np.float64) - - -class TestCleanFlatlinesValidation(DebuggableTestCase): - """Validation test cases for clean_flatlines function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_flatlines_empty_data(self): - """Test clean_flatlines with empty data.""" - eeg_empty = self.test_eeg.copy() - eeg_empty['data'] = np.array([]) - - # Should handle empty data gracefully - result = clean_flatlines(eeg_empty) - self.assertIsInstance(result, dict) - - def test_clean_flatlines_invalid_max_duration(self): - """Test clean_flatlines with invalid max_duration.""" - eeg_invalid = self.test_eeg.copy() - - # Test with negative duration - should handle gracefully - result = clean_flatlines(eeg_invalid, max_flatline_duration=-1.0) - self.assertIsInstance(result, dict) - - def test_clean_flatlines_invalid_max_jitter(self): - """Test clean_flatlines with invalid max_jitter.""" - eeg_invalid = self.test_eeg.copy() - - # Test with negative jitter - should handle gracefully - result = clean_flatlines(eeg_invalid, max_allowed_jitter=-1.0) - self.assertIsInstance(result, dict) - - def test_clean_flatlines_single_sample(self): - """Test clean_flatlines with single sample data.""" - eeg_single = self.test_eeg.copy() - eeg_single['data'] = np.random.randn(32, 1, 10) - eeg_single['pnts'] = 1 - - result = clean_flatlines(eeg_single) - - # Should handle single sample gracefully - self.assertEqual(result['pnts'], 1) - - def test_clean_flatlines_no_variance_data(self): - """Test clean_flatlines with data that has no variance.""" - eeg_no_var = self.test_eeg.copy() - eeg_no_var['data'] = np.ones_like(eeg_no_var['data']) - - result = clean_flatlines(eeg_no_var, max_flatline_duration=1.0) - - # Note: Current implementation may not detect flatlines as expected - # Test that the function completes without error - self.assertIsInstance(result, dict) - - def test_clean_flatlines_partial_flatlines(self): - """Test clean_flatlines with partial flatlines in channels.""" - eeg_partial = self.test_eeg.copy() - # Create partial flatlines - eeg_partial['data'][5, 100:200, :] = 1.0 # Partial flatline - eeg_partial['data'][10, 300:400, :] = 0.0 # Another partial flatline - - result = clean_flatlines(eeg_partial, max_flatline_duration=0.5) - - # Note: Current implementation may not detect flatlines as expected - # Test that the function completes without error - self.assertIsInstance(result, dict) - - # def test_clean_flatlines_pop_select_fallback(self): - # """Test clean_flatlines fallback when pop_select is not available.""" - # eeg_fallback = self.test_eeg.copy() - # eeg_fallback['data'][5, :, :] = 1.0 # Create flatline - - # # Mock the import to fail - # import sys - # original_import = __builtins__['__import__'] - - # def mock_import(name, *args, **kwargs): - # if name == 'eegprep': - # raise ImportError("Mock import error") - # return original_import(name, *args, **kwargs) - - # __builtins__['__import__'] = mock_import - - # try: - # result = clean_flatlines(eeg_fallback, max_flatline_duration=1.0) - # # Should still work with fallback - # self.assertIsInstance(result, dict) - # finally: - # __builtins__['__import__'] = original_import - - def test_clean_flatlines_chanlocs_mismatch(self): - """Test clean_flatlines with mismatched chanlocs.""" - eeg_mismatch = self.test_eeg.copy() - eeg_mismatch['chanlocs'] = eeg_mismatch['chanlocs'][:16] # Half the channels - - result = clean_flatlines(eeg_mismatch) - - # Should handle mismatch gracefully - self.assertIn('chanlocs', result) - - def test_clean_flatlines_walrus_operator_branch(self): - """Test clean_flatlines walrus operator branch (Python 3.8+).""" - eeg_walrus = self.test_eeg.copy() - eeg_walrus['etc'] = {'clean_channel_mask': np.ones(32, dtype=bool)} - eeg_walrus['data'][5, :, :] = 1.0 # Create flatline - - result = clean_flatlines(eeg_walrus, max_flatline_duration=1.0) - - # Should update existing mask if channel is removed - if result['nbchan'] < eeg_walrus['nbchan']: - self.assertFalse(result['etc']['clean_channel_mask'][5]) - - def test_clean_flatlines_fallback_composites_existing_mask(self): - """Fallback path with a prior clean_channel_mask must composite, not crash. - - Reproduces the walrus-precedence bug: when pop_select fails and a prior - clean_channel_mask exists, the mask update must run ``mask[mask] = ~removed`` - rather than treating the mask as a bool. Uses continuous (2D) data so the - composite indexing exercises the real fallback branch. - """ - eeg = self.test_eeg.copy() - eeg['data'] = np.random.randn(32, 1000) - eeg['trials'] = 1 - eeg['data'][5, :] = 1.0 # flatline channel 5 - eeg['etc'] = {'clean_channel_mask': np.ones(32, dtype=bool)} - # Empty chanlocs so the unrelated chanlocs-trim branch is skipped and the - # test isolates the clean_channel_mask compositing branch. - eeg['chanlocs'] = [] - - # Force the no-pop_select fallback with a non-ImportError so the - # mask-compositing branch runs (this is where the bug lived). - import eegprep - - original = eegprep.pop_select - - def failing_pop_select(*args, **kwargs): - raise RuntimeError("simulated pop_select failure") - - eegprep.pop_select = failing_pop_select - try: - result = clean_flatlines(eeg, max_flatline_duration=1.0) - finally: - eegprep.pop_select = original - - mask = result['etc']['clean_channel_mask'] - # Original mask had 32 True entries; after compositing exactly channel 5 - # (the flatline) must be False and the rest True. - self.assertEqual(mask.shape[0], 32) - self.assertFalse(mask[5]) - self.assertEqual(int(np.sum(~mask)), 1) - - -class TestCleanFlatlinesNoOpPath(DebuggableTestCase): - """No-operation path test cases for clean_flatlines function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_flatlines_no_op_no_flatlines_detected(self): - """Test clean_flatlines when no flatlines are detected.""" - eeg_no_flatlines = self.test_eeg.copy() - eeg_no_flatlines['data'] = np.random.randn(32, 1000, 10) - - result = clean_flatlines(eeg_no_flatlines) - - # Should not modify the data - self.assertEqual(result['nbchan'], eeg_no_flatlines['nbchan']) - np.testing.assert_array_equal(result['data'], eeg_no_flatlines['data']) - - def test_clean_flatlines_no_op_all_channels_flagged(self): - """Test clean_flatlines when all channels are flagged.""" - eeg_all_flagged = self.test_eeg.copy() - eeg_all_flagged['data'] = np.zeros_like(eeg_all_flagged['data']) - - result = clean_flatlines(eeg_all_flagged, max_flatline_duration=1.0) - - # Should not remove all channels (warning case) - self.assertEqual(result['nbchan'], eeg_all_flagged['nbchan']) - - def test_clean_flatlines_no_op_high_jitter_threshold(self): - """Test clean_flatlines with very high jitter threshold.""" - eeg_high_jitter = self.test_eeg.copy() - eeg_high_jitter['data'][5, :, :] = 1.0 # Create flatline - - result = clean_flatlines(eeg_high_jitter, max_allowed_jitter=1e6) - - # Should not remove channels with high jitter tolerance - self.assertEqual(result['nbchan'], eeg_high_jitter['nbchan']) - - def test_clean_flatlines_no_op_very_short_data(self): - """Test clean_flatlines with very short data.""" - eeg_short = self.test_eeg.copy() - eeg_short['data'] = np.random.randn(32, 10, 1) # Very short - eeg_short['pnts'] = 10 - eeg_short['trials'] = 1 - - result = clean_flatlines(eeg_short) - - # Should handle very short data gracefully - self.assertEqual(result['pnts'], 10) - - def test_clean_flatlines_no_op_boundary_conditions(self): - """Test clean_flatlines with boundary conditions.""" - eeg_boundary = self.test_eeg.copy() - # Create flatlines at boundaries - eeg_boundary['data'][5, 0:100, :] = 1.0 # Start boundary - eeg_boundary['data'][10, 900:1000, :] = 0.0 # End boundary - - result = clean_flatlines(eeg_boundary, max_flatline_duration=0.5) - - # Note: Current implementation may not detect flatlines as expected - # Test that the function completes without error - self.assertIsInstance(result, dict) - - -class TestCleanFlatlinesIntegration(DebuggableTestCase): - """Integration test cases for clean_flatlines function.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_eeg = create_test_eeg() - - def test_clean_flatlines_preserves_structure(self): - """Test that clean_flatlines preserves EEG structure.""" - original_eeg = self.test_eeg.copy() - - result = clean_flatlines(original_eeg) - - # Check that all required fields are preserved - required_fields = ['srate', 'pnts', 'trials', 'xmin', 'xmax', 'times'] - for field in required_fields: - self.assertIn(field, result) - if isinstance(original_eeg[field], np.ndarray): - np.testing.assert_array_equal(result[field], original_eeg[field]) - else: - self.assertEqual(result[field], original_eeg[field]) - - def test_clean_flatlines_chanlocs_consistency(self): - """Test that clean_flatlines maintains chanlocs consistency.""" - eeg_with_chanlocs = self.test_eeg.copy() - eeg_with_chanlocs['data'][5, :, :] = 1.0 # Create flatline - - result = clean_flatlines(eeg_with_chanlocs, max_flatline_duration=1.0) - - # Check that chanlocs matches the remaining channels - if result['nbchan'] < eeg_with_chanlocs['nbchan']: - self.assertEqual(len(result['chanlocs']), result['nbchan']) - else: - self.assertEqual(len(result['chanlocs']), len(eeg_with_chanlocs['chanlocs'])) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_clean_rawdata.py b/tests/test_clean_rawdata.py deleted file mode 100644 index 364d0c15..00000000 --- a/tests/test_clean_rawdata.py +++ /dev/null @@ -1,315 +0,0 @@ -import logging -import os -import unittest - -if os.getenv('EEGPREP_SKIP_MATLAB') == '1': - raise unittest.SkipTest("MATLAB not available") -from copy import deepcopy - -import numpy as np - -from eegprep import ( - clean_artifacts, - clean_asr, - clean_channels, - clean_channels_nolocs, - clean_drifts, - clean_flatlines, - clean_windows, - pop_loadset, -) -from eegprep.functions.adminfunc import eeglabcompat -from eegprep.utils.testing import DebuggableTestCase, compare_eeg, is_debug, use_64bit_eeg_options - -logger = logging.getLogger(__name__) - -# where the test resources -web_root = 'https://sccntestdatasets.s3.us-east-2.amazonaws.com/' -local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') - - -def ensure_file(fname: str) -> str: - """Download a file if it does not exist and return the local path.""" - full_url = f"{web_root}{fname}" - local_file = os.path.abspath(f"{local_url}{fname}") - if not os.path.exists(local_file): - from urllib.request import urlretrieve - - urlretrieve(full_url, local_file) - return local_file - - -@unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") -class TestMATLABAccess(unittest.TestCase): - def setUp(self): - try: - self.eeglab = eeglabcompat.get_eeglab('MAT') - self.EEG = pop_loadset(ensure_file('FlankerTest.set')) - except ImportError as e: - self.skipTest(f"MATLAB not available: {e}") - - def test_basic(self): - self.assertEqual(self.eeglab.sqrt(4.0), 2.0, 'MATLAB sqrt() failed') - - def test_eeglab_presence(self): - eeglabcompat.eeg_checkset(self.EEG, eeglab=self.eeglab) - - -class TestCleanFlatlines(unittest.TestCase): - def setUp(self): - # download file - self.EEG = pop_loadset(ensure_file('FlankerTest.set')) - self.EEG['data'][5, 1000:2000] = 3.5 # this should trigger - self.EEG['data'][7, 2000:2100] = 4.5 # this should not (too short) - self.EEG['data'][9, 3000:4000] = 5.5 # should trigger too - self.EEG['data'][15, 5000:10000] = np.random.randn(5000) * 1e-10 # should not trigger (too large amplitude) - self.expected = self.EEG['data'][~np.isin(np.arange(self.EEG['nbchan']), [5, 9]), :] - - def test_clean_flatlines(self): - cleaned_EEG = clean_flatlines(deepcopy(self.EEG), 3.5) - np.testing.assert_almost_equal(cleaned_EEG['data'], self.expected, err_msg='clean_flatlines() test failed') - - -@unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") -class TestUtilFuncs(DebuggableTestCase): - def setUp(self): - self.eeglab = eeglabcompat.get_eeglab('MAT') - - def test_design_kaiser(self): - from eegprep.plugins.firfilt.design import design_kaiser - - observed = design_kaiser(0.06, 0.08, 75, True) - expected = np.asarray(self.eeglab.design_kaiser(0.06, 0.08, 75.0, True)) - np.testing.assert_almost_equal(observed.flatten(), expected.flatten(), err_msg='design_kaiser() test failed') - - def test_design_fir_default_wnd(self): - from eegprep.plugins.firfilt.design import design_fir - - observed = design_fir(234, [0.0, 0.06, 0.08, 1.0], [0, 0, 1, 1]) - expected = np.asarray( - self.eeglab.design_fir(234.0, np.asarray([0.0, 0.06, 0.08, 1.0]), np.asarray([0.0, 0.0, 1.0, 1.0])) - ) - np.testing.assert_almost_equal( - observed.flatten(), expected.flatten(), err_msg='test_design_fir_default_wnd() test failed' - ) - - def test_design_fir_custom_wnd(self): - from eegprep.plugins.firfilt.design import design_fir, design_kaiser - - wnd = design_kaiser(0.06, 0.08, 75.0, True) - observed = design_fir(234, [0.0, 0.06, 0.08, 1.0], [0, 0, 1.0, 1.0], w=wnd) - expected = np.asarray( - self.eeglab.design_fir( - 234.0, np.asarray([0.0, 0.06, 0.08, 1.0]), np.asarray([0, 0, 1.0, 1.0]), np.asarray([]), wnd - ) - ) - np.testing.assert_almost_equal( - observed.flatten(), expected.flatten(), err_msg='test_design_fir_custom_wnd() test failed' - ) - - def test_block_geometric_median(self): - from eegprep.plugins.clean_rawdata.private.stats import block_geometric_median - - np.random.seed(42) - # generate heavy-tailed data with non-zero centroid and apply random rotation - df = 3 # degrees of freedom for t-distribution - center = np.arange(1, 33) # non-zero centroid vector - # random noise transform - R = np.random.randn(32, 32) - noise = np.random.standard_t(df, size=(5007, 32)) - X = noise.dot(R) + center - observed = block_geometric_median(X, 10) - expected = np.asarray(self.eeglab.block_geometric_median(X, 10.0)) - np.testing.assert_almost_equal( - observed.flatten(), expected.flatten(), err_msg='block_geometric_median() test failed' - ) - - def test_fit_eeg_distribution(self): - from eegprep.plugins.clean_rawdata.private.stats import fit_eeg_distribution - from scipy.stats import genextreme - - x = genextreme.rvs(0.1, size=5007) - observed, *_ = fit_eeg_distribution(x) # returns 4 values, for now we check only the first - expected = self.eeglab.fit_eeg_distribution(x) - # compare numbers - np.testing.assert_almost_equal(observed, expected, err_msg='fit_eeg_distribution() test failed') - - -class TestCleanDrifts(DebuggableTestCase): - def setUp(self): - self.EEG = pop_loadset(ensure_file('FlankerTest.set')) - - def test_clean_drifts(self): - eeglab = eeglabcompat.get_eeglab('MAT') - - # compare vs MATLAB - expected = eeglab.clean_drifts(self.EEG, [3, 4], 75) - cleaned1 = clean_drifts(deepcopy(self.EEG), [3, 4], 75, method='fir') - compare_eeg(cleaned1['data'], expected['data'], err_msg='clean_drifts() failed') - - # compare FFT vs FIR - cleaned2 = clean_drifts(deepcopy(self.EEG), [3, 4], 75, method='fft') - compare_eeg(cleaned1['data'], cleaned2['data'], err_msg='clean_drifts() FFT mode test failed', atol=2e-7) - - -class TestCleanChannels(DebuggableTestCase): - def setUp(self): - self.EEG = pop_loadset(ensure_file('EmotionValence.set')) - - def test_clean_channels_nolocs(self): - eeglab = eeglabcompat.get_eeglab('MAT') - cleaned, _ = clean_channels_nolocs(deepcopy(self.EEG), 0.9) - expected = eeglab.clean_channels_nolocs(self.EEG, 0.9) - compare_eeg(cleaned['data'], expected['data'], err_msg='clean_channels_nolocs() failed') - - def test_clean_channels_locs(self): - cleaned = clean_channels(deepcopy(self.EEG), 0.9) - eeglab = eeglabcompat.get_eeglab('MAT') - expected = eeglab.clean_channels(self.EEG, 0.9) - compare_eeg(cleaned['data'], expected['data'], err_msg='clean_channels() failed') - - -class TestCleanASR(DebuggableTestCase): - def setUp(self): - self.EEG = pop_loadset(ensure_file('EmotionValence.set')) - - def test_clean_asr_nowindow(self): - cleaned = clean_asr(deepcopy(self.EEG), ref_maxbadchannels='off') - eeglab = eeglabcompat.get_eeglab('MAT') - expected = eeglab.clean_asr(self.EEG, [], [], [], [], 'off') - compare_eeg( - cleaned['data'], - expected['data'], - atol=0, - rtol=1e-6, # because of eigh() precision differences - err_msg='clean_asr() failed vs MATLAB', - ) - - def test_riemannian(self): - """Test the Riemannian mode.""" - # for now this is just checking that it does not crash since we don't have - # MATLAB reference code for this - clean_asr(deepcopy(self.EEG), useriemannian='calib') - - -class TestCleanWindows(DebuggableTestCase): - def setUp(self): - self.EEG = pop_loadset(ensure_file('EmotionValence.set')) - - # ------------------------------------------------------------------ - # Inject synthetic high‑amplitude artefacts so that clean_windows() - # has something to remove. We draw 20 random stretches (200‑2000 - # samples) and corrupt a random subset of channels in each stretch - # by adding scaled Gaussian noise. - # ------------------------------------------------------------------ - rng = np.random.default_rng(seed=42) # deterministic for the test - - n_stretches = 20 - n_chans = self.EEG['nbchan'] - n_samp = self.EEG['pnts'] - - for _ in range(n_stretches): - # Random stretch length between 200 and 2000 samples (inclusive) - stretch_len = int(rng.integers(200, 2001)) - - # Random start index ensuring we stay within bounds - start = int(rng.integers(0, max(1, n_samp - stretch_len))) - stop = start + stretch_len - - # Random number of channels (at least 1, at most all channels) - n_corrupt = int(rng.integers(1, n_chans + 1)) - chan_idx = rng.choice(n_chans, size=n_corrupt, replace=False) - - # Local scale (robust): std of the windowed data - window_data = self.EEG['data'][chan_idx, start:stop] - local_std = float(np.std(window_data, ddof=0)) - if local_std == 0: - local_std = 1.0 # avoid degenerate case - - # Heavy‑tailed positive scale factor (Gamma with mean ≈3) - scale_factor = float(rng.gamma(shape=2.0, scale=1.5)) - - noise = rng.standard_normal(size=(n_corrupt, stretch_len)) * local_std * scale_factor - self.EEG['data'][chan_idx, start:stop] += noise - - def test_clean_windows(self): - cleaned, _ = clean_windows(deepcopy(self.EEG)) - eeglab = eeglabcompat.get_eeglab('MAT') - expected = eeglab.clean_windows(self.EEG) - compare_eeg(cleaned['data'], expected['data'], err_msg='clean_windows() failed vs MATLAB') - - def test_clean_windows_preserves_float64(self): - cleaned, _ = clean_windows(deepcopy(self.EEG)) - self.assertEqual(cleaned['data'].dtype, np.float64) - - -# ------------------------------------------------------------------------------ -# clean_artifacts -# ------------------------------------------------------------------------------ - - -class TestCleanArtifacts(DebuggableTestCase): - def setUp(self): - # Use the same dataset as other heavy‑duty tests - self.EEG = pop_loadset(ensure_file('EmotionValence.set')) - - def test_clean_artifacts_defaults(self): - """Compare Python clean_artifacts against MATLAB implementation (default params).""" - with use_64bit_eeg_options(): - # --- Python version --- - cleaned_py, _, _, _ = clean_artifacts(deepcopy(self.EEG)) - - # --- MATLAB reference --- - eeglab = eeglabcompat.get_eeglab('MAT') - # Call with the matching name‑value pair - expected_mat = eeglab.clean_artifacts(self.EEG) - - compare_eeg( - cleaned_py['data'], - expected_mat['data'], - rtol=0, - atol=1e-5, # limit to 1e-5 uV likely due to solver differences - err_msg='clean_artifacts() failed vs MATLAB', - ) - - -class TestCleanArtifactsAdvanced(DebuggableTestCase): - def setUp(self): - # Use the same dataset as other heavy‑duty tests - self.EEG = pop_loadset(ensure_file('eeglab_data_with_ica_tmp.set')) - - def test_clean_artifacts_alt_defaults(self): - """Compare Python clean_artifacts against MATLAB implementation (alt parameters).""" - kwargs = dict( - FlatlineCriterion=5, - ChannelCriterion=0.87, - LineNoiseCriterion=4, - Highpass=[0.25, 0.75], - BurstCriterion=20, - WindowCriterion=0.25, - WindowCriterionTolerances=[float('-inf'), 7], - ) - - # --- Python version --- - cleaned_py, _, _, _ = clean_artifacts(deepcopy(self.EEG), **kwargs) - - # --- MATLAB reference --- - eeglab = eeglabcompat.get_eeglab('MAT') - expected_mat = eeglab.clean_artifacts(self.EEG, **kwargs) - - compare_eeg( - cleaned_py['data'], - expected_mat['data'], - rtol=0, - atol=2e-5, # limit to 2e-5 uV due to solver and floating point differences - err_msg='clean_artifacts() failed vs MATLAB', - ) - - -if __name__ == "__main__": - # run TestCleanDrifts only - if is_debug(): - # put the test here that you want to run in the debugger - TestCleanASR.debugTestCase() - else: - unittest.main() diff --git a/tests/test_clean_rawdata_channel_removal.py b/tests/test_clean_rawdata_channel_removal.py deleted file mode 100644 index d47fc7b7..00000000 --- a/tests/test_clean_rawdata_channel_removal.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from eegprep.plugins.clean_rawdata.private.channel_removal import update_clean_channel_mask - - -def test_update_clean_channel_mask_composites_original_channel_mask(): - eeg = {"etc": {"clean_channel_mask": np.array([True, False, True, True])}} - removed_channels = np.array([False, True, False]) - - update_clean_channel_mask(eeg, removed_channels) - - np.testing.assert_array_equal(eeg["etc"]["clean_channel_mask"], [True, False, False, True]) - - -def test_update_clean_channel_mask_resets_mismatched_existing_mask(): - eeg = {"etc": {"clean_channel_mask": np.array([True, False, True, True])}} - removed_channels = np.array([False, True]) - - update_clean_channel_mask(eeg, removed_channels) - - np.testing.assert_array_equal(eeg["etc"]["clean_channel_mask"], [True, False]) - - -def test_update_clean_channel_mask_creates_zero_and_all_removal_masks(): - no_removal = {"etc": {}} - all_removed = {} - - update_clean_channel_mask(no_removal, np.zeros(3, dtype=bool)) - update_clean_channel_mask(all_removed, np.ones(3, dtype=bool)) - - np.testing.assert_array_equal(no_removal["etc"]["clean_channel_mask"], [True, True, True]) - np.testing.assert_array_equal(all_removed["etc"]["clean_channel_mask"], [False, False, False]) diff --git a/tests/test_clean_windows.py b/tests/test_clean_windows.py deleted file mode 100644 index 096ff313..00000000 --- a/tests/test_clean_windows.py +++ /dev/null @@ -1,512 +0,0 @@ -import unittest -import numpy as np -from unittest.mock import patch - -from eegprep.plugins.clean_rawdata.clean_windows import clean_windows - - -class TestCleanWindows(unittest.TestCase): - """Test the clean_windows function.""" - - def setUp(self): - """Set up test fixtures with synthetic EEG data.""" - np.random.seed(42) # For reproducible tests - - # Create synthetic EEG data structure - self.n_channels = 8 - self.n_samples = 2500 # 10 seconds at 250 Hz - self.srate = 250.0 - - # Create clean EEG data - self.clean_data = np.random.randn(self.n_channels, self.n_samples) * 0.5 - - # Add some realistic structure - for ch in range(self.n_channels): - # Add some low-frequency trend - t = np.linspace(0, self.n_samples / self.srate, self.n_samples) - self.clean_data[ch] += 0.2 * np.sin(2 * np.pi * 0.5 * t) - - # Add some artifacts to specific windows - self.data_with_artifacts = self.clean_data.copy() - - # Add high-amplitude artifacts to specific channels/times - self.data_with_artifacts[2, 500:750] += np.random.randn(250) * 5.0 # Large artifact - self.data_with_artifacts[5, 1500:1750] += np.random.randn(250) * 4.0 # Another artifact - - self.EEG_clean = { - 'data': self.clean_data, - 'srate': self.srate, - 'pnts': self.n_samples, - 'nbchan': self.n_channels, - 'xmin': 0.0, - 'xmax': (self.n_samples - 1) / self.srate, - } - - self.EEG_artifacts = { - 'data': self.data_with_artifacts, - 'srate': self.srate, - 'pnts': self.n_samples, - 'nbchan': self.n_channels, - 'xmin': 0.0, - 'xmax': (self.n_samples - 1) / self.srate, - } - - def test_basic_functionality(self): - """Test basic clean_windows functionality.""" - EEG_out, sample_mask = clean_windows(self.EEG_artifacts.copy()) - - # Check that output is an EEG dict and boolean mask - self.assertIsInstance(EEG_out, dict) - self.assertIsInstance(sample_mask, np.ndarray) - self.assertEqual(sample_mask.dtype, bool) - self.assertEqual(len(sample_mask), self.n_samples) - - # Check that some artifacts were detected (mask should have False values) - self.assertFalse(np.all(sample_mask)) - - # Check that output data has fewer samples than input - self.assertLessEqual(EEG_out['pnts'], self.n_samples) - - # Check that data is finite - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - def test_window_criterion_thresholds(self): - """Test WindowCriterion thresholds with different z-score limits.""" - test_thresholds = [(-2, 3), (-3.5, 5), (-5, 7)] - - results = [] - for zthresh in test_thresholds: - with self.subTest(zthresholds=zthresh): - EEG_out, sample_mask = clean_windows(self.EEG_artifacts.copy(), zthresholds=zthresh) - results.append((EEG_out, sample_mask)) - - # More lenient thresholds should retain more data - kept_pct = np.mean(sample_mask) * 100 - self.assertGreaterEqual(kept_pct, 0) - self.assertLessEqual(kept_pct, 100) - - # More lenient thresholds should generally keep more data - # Compare (-2, 3) vs (-5, 7) - strict_kept = np.mean(results[0][1]) - lenient_kept = np.mean(results[2][1]) - self.assertLessEqual(strict_kept, lenient_kept) - - def test_max_bad_channels_parameter(self): - """Test max_bad_channels parameter as both fraction and absolute count.""" - # Test as fraction - EEG_out1, mask1 = clean_windows( - self.EEG_artifacts.copy(), - max_bad_channels=0.25, # 25% of channels - ) - - # Test as absolute count - EEG_out2, mask2 = clean_windows( - self.EEG_artifacts.copy(), - max_bad_channels=2, # 2 channels - ) - - # Both should work and produce valid results - self.assertTrue(np.all(np.isfinite(EEG_out1['data']))) - self.assertTrue(np.all(np.isfinite(EEG_out2['data']))) - - # Results should be different (unless by coincidence) - if not np.array_equal(mask1, mask2): - self.assertFalse(np.array_equal(mask1, mask2)) - - def test_window_parameters(self): - """Test different window length and overlap parameters.""" - test_params = [ - {'window_len': 0.5, 'window_overlap': 0.5}, - {'window_len': 1.0, 'window_overlap': 0.66}, - {'window_len': 2.0, 'window_overlap': 0.8}, - ] - - for params in test_params: - with self.subTest(**params): - EEG_out, sample_mask = clean_windows(self.EEG_artifacts.copy(), **params) - - # Should complete without errors - self.assertIsInstance(EEG_out, dict) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - self.assertEqual(len(sample_mask), self.n_samples) - - def test_distribution_fitting_parameters(self): - """Test distribution fitting parameters.""" - EEG_out, sample_mask = clean_windows( - self.EEG_artifacts.copy(), - max_dropout_fraction=0.2, - min_clean_fraction=0.3, - truncate_quant=(0.05, 0.7), - step_sizes=(0.02, 0.02), - shape_range=np.arange(1.5, 4.0, 0.2), - ) - - # Should complete successfully - self.assertIsInstance(EEG_out, dict) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - self.assertEqual(len(sample_mask), self.n_samples) - - def test_2d_vs_3d_data(self): - """Test with 2D (continuous) vs 3D (epoched) data.""" - # Test 2D data (continuous) - this is the main use case - EEG_2d = self.EEG_artifacts.copy() - EEG_out_2d, mask_2d = clean_windows(EEG_2d) - - self.assertEqual(len(EEG_out_2d['data'].shape), 2) - self.assertTrue(np.all(np.isfinite(EEG_out_2d['data']))) - - # Note: The function expects 2D data (channels x samples) - # 3D data would need to be handled differently or would cause an error - # This is documented behavior for clean_windows - - def test_no_windows_case(self): - """Test case where no windows are removed (all clean data).""" - # Use very lenient thresholds that should keep all data - EEG_out, sample_mask = clean_windows( - self.EEG_clean.copy(), - zthresholds=(-10, 10), # Very lenient - max_bad_channels=self.n_channels, # Allow all channels to be bad - ) - - # Should keep most/all data - kept_pct = np.mean(sample_mask) * 100 - self.assertGreaterEqual(kept_pct, 90) # Should keep at least 90% - - # Output should be valid - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - def test_all_windows_removed_case(self): - """Test case where all windows would be removed (very strict thresholds).""" - # Create data with extreme artifacts everywhere - extreme_data = self.clean_data.copy() - extreme_data += np.random.randn(*extreme_data.shape) * 10 # Add large noise everywhere - - EEG_extreme = self.EEG_artifacts.copy() - EEG_extreme['data'] = extreme_data - - # Use very strict thresholds - EEG_out, sample_mask = clean_windows( - EEG_extreme, - zthresholds=(-0.1, 0.1), # Very strict - max_bad_channels=0, # No bad channels allowed - ) - - # Should remove most data - kept_pct = np.mean(sample_mask) * 100 - self.assertLessEqual(kept_pct, 50) # Should remove at least 50% - - # Output should still be valid (even if very little data remains) - if EEG_out['pnts'] > 0: - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - def test_tolerances_array_shape_range(self): - """Test different shape_range (beta parameter) arrays.""" - # Test different shape ranges for the generalized Gaussian - shape_ranges = [np.arange(1.0, 2.5, 0.1), np.arange(2.0, 4.0, 0.2), np.array([1.5, 2.0, 2.5, 3.0])] - - for shape_range in shape_ranges: - with self.subTest(shape_range=shape_range): - EEG_out, sample_mask = clean_windows(self.EEG_artifacts.copy(), shape_range=shape_range) - - # Should complete successfully - self.assertIsInstance(EEG_out, dict) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - def test_edge_case_empty_data(self): - """Test error handling with empty data.""" - empty_EEG = {'data': np.empty((0, 0)), 'srate': 250.0, 'pnts': 0, 'nbchan': 0} - - with self.assertRaises(ValueError) as cm: - clean_windows(empty_EEG) - - self.assertIn('Empty data array', str(cm.exception)) - - def test_edge_case_single_channel(self): - """Test with single channel data.""" - single_ch_data = self.data_with_artifacts[0:1, :] # Take only first channel - single_ch_EEG = { - 'data': single_ch_data, - 'srate': self.srate, - 'pnts': self.n_samples, - 'nbchan': 1, - 'xmin': 0.0, - 'xmax': (self.n_samples - 1) / self.srate, - } - - EEG_out, sample_mask = clean_windows(single_ch_EEG) - - # Should work with single channel - self.assertEqual(EEG_out['data'].shape[0], 1) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - def test_edge_case_very_short_data(self): - """Test with very short data that might not fit a full window.""" - short_data = self.data_with_artifacts[:, :100] # Only 100 samples - short_EEG = { - 'data': short_data, - 'srate': self.srate, - 'pnts': 100, - 'nbchan': self.n_channels, - 'xmin': 0.0, - 'xmax': 99 / self.srate, - } - - # With default window_len=1.0s (250 samples), this should raise an error - with self.assertRaises(ValueError) as cm: - clean_windows(short_EEG) - - self.assertIn('Not enough data for even a single window', str(cm.exception)) - - # But with shorter window, it should work - EEG_out, sample_mask = clean_windows(short_EEG, window_len=0.2) # 50 samples - self.assertIsInstance(EEG_out, dict) - - def test_window_length_validation(self): - """Test window length parameter validation.""" - # Test zero/negative window length - with self.assertRaises(ValueError) as cm: - clean_windows(self.EEG_artifacts.copy(), window_len=0) - - self.assertIn('Window length too small', str(cm.exception)) - - with self.assertRaises(ValueError) as cm: - clean_windows(self.EEG_artifacts.copy(), window_len=-1) - - self.assertIn('Window length too small', str(cm.exception)) - - def test_window_overlap_edge_cases(self): - """Test window overlap edge cases.""" - # Test overlap >= 1 (should be handled gracefully) - EEG_out, sample_mask = clean_windows( - self.EEG_artifacts.copy(), - window_overlap=1.0, # 100% overlap - ) - - # Should work (function sets step=1 to avoid infinite loop) - self.assertIsInstance(EEG_out, dict) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - # Test overlap > 1 - EEG_out2, sample_mask2 = clean_windows( - self.EEG_artifacts.copy(), - window_overlap=1.5, # 150% overlap - ) - - # Should also work - self.assertIsInstance(EEG_out2, dict) - self.assertTrue(np.all(np.isfinite(EEG_out2['data']))) - - def test_distribution_fitting_fallback(self): - """Test fallback when distribution fitting fails.""" - # The function has built-in fallback logic for when sigma=0 or NaN - # We can test this by creating data that might cause fitting issues - constant_data = np.ones((4, 1000)) * 0.5 # Constant data might cause sigma=0 - constant_EEG = {'data': constant_data, 'srate': 250.0, 'pnts': 1000, 'nbchan': 4, 'xmin': 0.0, 'xmax': 3.996} - - # Should complete using MAD fallback if distribution fitting fails - EEG_out, sample_mask = clean_windows(constant_EEG) - - self.assertIsInstance(EEG_out, dict) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - def test_distribution_fitting_nan_fallback(self): - """Test fallback when distribution fitting returns NaN.""" - with patch('eegprep.plugins.clean_rawdata.private.stats.fit_eeg_distribution') as mock_fit: - # Return NaN sigma to trigger fallback - mock_fit.return_value = (1.0, np.nan, None, None) - - EEG_out, sample_mask = clean_windows(self.EEG_artifacts.copy()) - - # Should complete using MAD fallback - self.assertIsInstance(EEG_out, dict) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - def test_pop_select_integration(self): - """pop_select success path should preserve the input data dtype.""" - original_dtype = self.EEG_artifacts['data'].dtype - EEG_out, sample_mask = clean_windows(self.EEG_artifacts.copy()) - - # Should complete successfully via pop_select - self.assertIsInstance(EEG_out, dict) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - self.assertEqual(EEG_out['data'].dtype, original_dtype) - - def test_direct_rejection_preserves_dtype(self): - """Direct eeg_eegrej sample rejection should preserve the input data dtype.""" - original_dtype = self.EEG_artifacts['data'].dtype - EEG_out, sample_mask = clean_windows(self.EEG_artifacts.copy()) - - # Should produce valid output using direct sample rejection. - self.assertIsInstance(EEG_out, dict) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - self.assertEqual(EEG_out['data'].dtype, original_dtype) - - def test_clean_sample_mask_handling(self): - """Test handling of EEG.etc.clean_sample_mask field.""" - # Test with no existing mask - EEG_test = self.EEG_artifacts.copy() - EEG_out, sample_mask = clean_windows(EEG_test) - - # Should create etc.clean_sample_mask - self.assertIn('etc', EEG_out) - self.assertIn('clean_sample_mask', EEG_out['etc']) - np.testing.assert_array_equal(EEG_out['etc']['clean_sample_mask'], sample_mask) - - # Test with existing compatible mask - EEG_test2 = self.EEG_artifacts.copy() - existing_mask = np.ones(self.n_samples, dtype=bool) - existing_mask[1000:1200] = False # Some previous cleaning - EEG_test2['etc'] = {'clean_sample_mask': existing_mask} - - EEG_out2, sample_mask2 = clean_windows(EEG_test2) - - # Should update the existing mask - self.assertIn('clean_sample_mask', EEG_out2['etc']) - - # Test with existing incompatible mask - EEG_test3 = self.EEG_artifacts.copy() - incompatible_mask = np.ones(100, dtype=bool) # Wrong size - EEG_test3['etc'] = {'clean_sample_mask': incompatible_mask} - - with self.assertLogs('eegprep.plugins.clean_rawdata.clean_windows', level='WARNING') as log: - EEG_out3, sample_mask3 = clean_windows(EEG_test3) - - # Should log warning and overwrite - self.assertTrue(any('incompatible' in msg for msg in log.output)) - np.testing.assert_array_equal(EEG_out3['etc']['clean_sample_mask'], sample_mask3) - - def test_direct_rejection_data_processing(self): - """Direct sample rejection keeps EEGLAB eegrej timing semantics.""" - EEG_out, sample_mask = clean_windows(self.EEG_artifacts.copy()) - - # Check that fallback processing was applied without changing precision. - self.assertEqual(EEG_out['data'].dtype, self.EEG_artifacts['data'].dtype) - - # pnts and xmax should be updated - self.assertEqual(EEG_out['pnts'], EEG_out['data'].shape[1]) - old_duration = self.EEG_artifacts['xmax'] - self.EEG_artifacts['xmin'] - expected_xmax = self.EEG_artifacts['xmin'] + old_duration * EEG_out['pnts'] / self.EEG_artifacts['pnts'] - self.assertAlmostEqual(EEG_out['xmax'], expected_xmax, places=6) - - def test_pop_select_success_preserves_events_and_inserts_boundaries(self): - """pop_select success path keeps events and inserts boundaries at cuts. - - Mirrors EEGLAB's clean_windows.m, which only wipes event metadata in - the manual fallback branch. On the success path, pop_select / eeg_eegrej - shift event latencies and insert a 'boundary' event at each cut with - duration equal to the excised sample count. - """ - EEG_in = self.EEG_artifacts.copy() - # Pre-populate events at known sample latencies, covering survivors, - # an event inside an artifact region, and a marker after the second - # artifact so we can verify post-cut latency shifting. - EEG_in['event'] = [ - {'type': 'S1', 'latency': 100.0, 'duration': 0.0}, - {'type': 'S2', 'latency': 600.0, 'duration': 0.0}, # inside first artifact (500:750) - {'type': 'S3', 'latency': 1000.0, 'duration': 0.0}, - {'type': 'S4', 'latency': 2000.0, 'duration': 0.0}, - ] - EEG_in['urevent'] = [dict(ev) for ev in EEG_in['event']] - for i, ev in enumerate(EEG_in['event'], start=1): - ev['urevent'] = i - EEG_in['nbchan'] = self.n_channels - EEG_in['trials'] = 1 - - EEG_out, sample_mask = clean_windows(EEG_in) - - # The success path should not wipe event metadata. - self.assertIn('event', EEG_out) - events = list(EEG_out['event']) - self.assertGreater(len(events), 0) - - # At least one boundary event should have been inserted (artifacts - # produced cuts, so sample_mask has False stretches). - boundary_events = [ev for ev in events if str(ev.get('type', '')).lower() == 'boundary'] - self.assertGreaterEqual(len(boundary_events), 1) - - # Surviving events must lie within the new sample grid. - new_pnts = EEG_out['pnts'] - for ev in events: - if 'latency' in ev: - self.assertGreaterEqual(float(ev['latency']), 0.0) - self.assertLessEqual(float(ev['latency']), float(new_pnts) + 1) - - # Boundary durations should be positive and not exceed the total - # number of removed samples. - total_removed = int(np.sum(~sample_mask)) - for ev in boundary_events: - self.assertGreater(float(ev.get('duration', 0.0)), 0.0) - self.assertLessEqual(float(ev.get('duration', 0.0)), float(total_removed)) - - self.assertEqual(EEG_out['data'].dtype, EEG_in['data'].dtype) - - def test_logging_output(self): - """Test that appropriate logging messages are generated.""" - with self.assertLogs('eegprep.plugins.clean_rawdata.clean_windows', level='INFO') as log: - clean_windows(self.EEG_artifacts.copy()) - - # Should log threshold determination and completion - self.assertTrue(any('Determining time window rejection thresholds' in msg for msg in log.output)) - self.assertTrue(any('done.' in msg for msg in log.output)) - self.assertTrue(any('Keeping' in msg and '% (' in msg and 'seconds) of the data' in msg for msg in log.output)) - - def test_different_data_types(self): - """Test with different input data types.""" - for dtype in [np.float32, np.float64, np.int16, np.int32]: - with self.subTest(dtype=dtype): - EEG_test = self.EEG_artifacts.copy() - EEG_test['data'] = EEG_test['data'].astype(dtype) - - EEG_out, sample_mask = clean_windows(EEG_test) - - # Should work regardless of input type - self.assertIsInstance(EEG_out, dict) - self.assertTrue(np.all(np.isfinite(EEG_out['data']))) - - self.assertTrue(np.issubdtype(EEG_out['data'].dtype, np.floating)) - if np.issubdtype(np.dtype(dtype), np.floating): - self.assertEqual(EEG_out['data'].dtype, np.dtype(dtype)) - else: - self.assertEqual(EEG_out['data'].dtype, np.dtype(np.float64)) - - def test_sample_mask_consistency(self): - """Test that sample_mask correctly corresponds to retained data.""" - EEG_out, sample_mask = clean_windows(self.EEG_artifacts.copy()) - - # The sample_mask should have the same length as original data - self.assertEqual(len(sample_mask), self.n_samples) - - # The number of True values should relate to the output data size - # (Note: pop_select might do additional processing, so exact equality may not hold) - n_kept_samples = np.sum(sample_mask) - - # At minimum, output should not have more samples than the mask indicates - self.assertLessEqual(EEG_out['pnts'], n_kept_samples) - - # Check that sample_mask is boolean - self.assertEqual(sample_mask.dtype, bool) - - def test_does_not_mutate_input(self): - """clean_windows must not mutate the caller's EEG dict or data array.""" - EEG_in = self.EEG_artifacts - original_data = EEG_in['data'].copy() - original_dtype = EEG_in['data'].dtype - original_keys = set(EEG_in.keys()) - - EEG_out, _ = clean_windows(EEG_in) - - # Caller's data array is unchanged in value, dtype, and shape. - self.assertTrue(np.array_equal(original_data, EEG_in['data'])) - self.assertEqual(EEG_in['data'].dtype, original_dtype) - # No new keys (e.g. 'etc') were injected into the caller's dict. - self.assertEqual(set(EEG_in.keys()), original_keys) - self.assertNotIn('etc', EEG_in) - # Output is a distinct object from the input. - self.assertIsNot(EEG_out, EEG_in) - self.assertIsNot(EEG_out['data'], EEG_in['data']) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_console_workspace.py b/tests/test_console_workspace.py index 8d51c7ef..8f54d861 100644 --- a/tests/test_console_workspace.py +++ b/tests/test_console_workspace.py @@ -729,7 +729,6 @@ def fake_run_in_terminal(callback): [ ("pop_adjustevents", "eegprep.functions.popfunc.pop_adjustevents.pop_adjustevents"), ("pop_chanedit", "eegprep.functions.popfunc.pop_chanedit.pop_chanedit"), - ("pop_clean_rawdata", "eegprep.plugins.clean_rawdata.pop_clean_rawdata.pop_clean_rawdata"), ("pop_comments", "eegprep.functions.popfunc.pop_comments.pop_comments"), ("pop_editset", "eegprep.functions.popfunc.pop_editset.pop_editset"), ("pop_editeventfield", "eegprep.functions.popfunc.pop_editeventfield.pop_editeventfield"), @@ -742,8 +741,6 @@ def fake_run_in_terminal(callback): ("pop_runica", "eegprep.functions.popfunc.pop_runica.pop_runica"), ("pop_select", "eegprep.functions.popfunc.pop_select.pop_select"), ("pop_selectevent", "eegprep.functions.popfunc.pop_selectevent.pop_selectevent"), - ("pop_iclabel", "eegprep.plugins.ICLabel.pop_iclabel.pop_iclabel"), - ("pop_icflag", "eegprep.plugins.ICLabel.pop_icflag.pop_icflag"), ("pop_subcomp", "eegprep.functions.popfunc.pop_subcomp.pop_subcomp"), ], ) @@ -751,7 +748,6 @@ def test_gui_pop_action_warning_output_follows_echoed_command(action, patch_targ from eegprep.functions.guifunc.menu_actions import MenuActionDispatcher newset_actions = { - "pop_clean_rawdata", "pop_epoch", "pop_interp", "pop_reref", diff --git a/tests/test_eeg_autocorr.py b/tests/test_eeg_autocorr.py deleted file mode 100644 index 08cd0634..00000000 --- a/tests/test_eeg_autocorr.py +++ /dev/null @@ -1,387 +0,0 @@ -""" -Test suite for eeg_autocorr.py with MATLAB parity validation. - -This module tests the eeg_autocorr function which computes autocorrelation -of ICA components for EEG data. -""" - -# Disable multithreading for deterministic numerical results -import os - -os.environ["OMP_NUM_THREADS"] = "1" -os.environ["MKL_NUM_THREADS"] = "1" -os.environ["NUMEXPR_NUM_THREADS"] = "1" -os.environ["OPENBLAS_NUM_THREADS"] = "1" -os.environ["VECLIB_MAXIMUM_THREADS"] = "1" - -import unittest -import sys -import numpy as np -import warnings - -# Add src to path for imports -sys.path.insert(0, 'src') -from eegprep.plugins.ICLabel.eeg_autocorr import eeg_autocorr -from eegprep.functions.adminfunc.eeglabcompat import get_eeglab -from eegprep.utils.testing import DebuggableTestCase - - -@unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") -class TestEegAutocorr(DebuggableTestCase): - """Test cases for eeg_autocorr function.""" - - def setUp(self): - """Set up test fixtures.""" - # Set up MATLAB compatibility for parity tests - try: - self.eeglab = get_eeglab() - self.matlab_available = True - except Exception: - self.matlab_available = False - - def create_test_eeg(self, ncomp=10, pnts=1000, trials=1, srate=256): - """Create a test EEG structure with ICA data.""" - # Create realistic ICA activations - np.random.seed(42) # For reproducible tests - nbchan = ncomp # Use same number of channels as components for valid ICA structure - - # Data and ICA shapes depend on trials - if trials == 1: - icaact = np.random.randn(ncomp, pnts).astype(np.float64) - data = np.random.randn(nbchan, pnts).astype(np.float64) - else: - icaact = np.random.randn(ncomp, pnts, trials).astype(np.float64) - data = np.random.randn(nbchan, pnts, trials).astype(np.float64) - - # Create channel locations - chanlocs = np.array( - [ - {'labels': f'E{i + 1}', 'X': 0.0, 'Y': 0.0, 'Z': 0.0, 'theta': 0.0, 'radius': 0.0, 'type': 'EEG'} - for i in range(nbchan) - ] - ) - - # ICA matrices - consistent dimensions for MATLAB compatibility - # icaweights: (ncomp, nbchan), icasphere: (nbchan, nbchan) - # icaweights * icasphere * data(icachansind,:) should work - icaweights = np.eye(ncomp, nbchan).astype(np.float64) - icasphere = np.eye(nbchan).astype(np.float64) - icawinv = np.eye(nbchan, ncomp).astype(np.float64) - icachansind = np.arange(1, nbchan + 1, dtype=np.float64) # 1-based, float for MATLAB - - return { - 'icaact': icaact, - 'pnts': pnts, - 'srate': srate, - 'trials': trials, - 'nbchan': nbchan, - 'icaweights': icaweights, - 'icasphere': icasphere, - 'icawinv': icawinv, - 'icachansind': icachansind, - 'data': data, - 'xmin': 0.0, - 'xmax': (pnts - 1) / srate, - 'times': np.linspace(0, (pnts - 1) / srate * 1000, pnts), # in ms - 'chanlocs': chanlocs, - 'urchanlocs': np.array([]), - 'chaninfo': {}, - 'ref': 'common', - 'history': '', - 'saved': 'no', - 'etc': {}, - } - - def test_basic_autocorrelation(self): - """Test basic autocorrelation computation.""" - EEG = self.create_test_eeg(ncomp=5, pnts=512, srate=256) - - result = eeg_autocorr(EEG) - - # Check output shape - expected_samples = 100 # Resampled to 100 Hz - 1 (first sample removed) - self.assertEqual(result.shape, (5, expected_samples)) - - # Check that all values are finite - self.assertTrue(np.all(np.isfinite(result))) - - # Check data type - self.assertTrue(result.dtype == np.float32 or result.dtype == np.float64) - - def test_default_pct_data(self): - """Test default pct_data parameter.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - # Test with default pct_data (should be 100) - result1 = eeg_autocorr(EEG) - result2 = eeg_autocorr(EEG, pct_data=100) - - # Octave loading in setUp affects numerical precision - # Observed: max_abs ~1e-08, max_rel ~1e-05 - np.testing.assert_allclose(result1, result2, rtol=2e-5, atol=2e-8) - - def test_explicit_pct_data(self): - """Test explicit pct_data parameter.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - # Test with explicit pct_data - result = eeg_autocorr(EEG, pct_data=50) - - # Should still produce same shape output - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_different_sample_rates(self): - """Test with different sampling rates.""" - test_cases = [ - {'srate': 128, 'expected_samples': 100}, - {'srate': 256, 'expected_samples': 100}, - {'srate': 500, 'expected_samples': 100}, - {'srate': 1000, 'expected_samples': 100}, - ] - - for case in test_cases: - with self.subTest(srate=case['srate']): - EEG = self.create_test_eeg(ncomp=2, pnts=512, srate=case['srate']) - result = eeg_autocorr(EEG) - - self.assertEqual(result.shape[0], 2) - self.assertEqual(result.shape[1], case['expected_samples']) - self.assertTrue(np.all(np.isfinite(result))) - - def test_short_data_padding(self): - """Test case where pnts < srate (requires padding).""" - EEG = self.create_test_eeg(ncomp=3, pnts=100, srate=256) # pnts < srate - - result = eeg_autocorr(EEG) - - # Should still produce expected output shape - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_long_data_truncation(self): - """Test case where pnts > srate (requires truncation).""" - EEG = self.create_test_eeg(ncomp=3, pnts=1000, srate=256) # pnts > srate - - result = eeg_autocorr(EEG) - - # Should still produce expected output shape - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_single_component(self): - """Test with single ICA component.""" - EEG = self.create_test_eeg(ncomp=1, pnts=512, srate=256) - - result = eeg_autocorr(EEG) - - self.assertEqual(result.shape, (1, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_many_components(self): - """Test with many ICA components.""" - EEG = self.create_test_eeg(ncomp=50, pnts=512, srate=256) - - result = eeg_autocorr(EEG) - - self.assertEqual(result.shape, (50, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_icaact_conversion_to_float32(self): - """Test that icaact is converted to float32.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - # Start with float64 data - EEG['icaact'] = EEG['icaact'].astype(np.float64) - - original_dtype = EEG['icaact'].dtype - self.assertEqual(original_dtype, np.float64) - - result = eeg_autocorr(EEG) - - # After processing, icaact should be float32 - self.assertEqual(EEG['icaact'].dtype, np.float32) - self.assertTrue(np.all(np.isfinite(result))) - - def test_zero_component(self): - """Test with zero-valued component (edge case).""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - # Make one component all zeros (handle 2D for trials=1) - if EEG['icaact'].ndim == 2: - EEG['icaact'][1, :] = 0 - else: - EEG['icaact'][1, :, :] = 0 - - result = eeg_autocorr(EEG) - - self.assertEqual(result.shape, (3, 100)) - # Zero component should produce NaN or inf values after normalization - # but the function should handle this gracefully - self.assertTrue(np.all(np.isfinite(result[0, :]))) # First component should be fine - self.assertTrue(np.all(np.isfinite(result[2, :]))) # Third component should be fine - # Second component (zero) might have NaN or inf, which is expected - - def test_fft_size_calculation(self): - """Test that FFT size is calculated correctly.""" - # Test with different data lengths to verify nfft calculation - test_cases = [ - {'pnts': 100, 'expected_nfft_min': 128}, # 2^7 - {'pnts': 256, 'expected_nfft_min': 512}, # 2^9 - {'pnts': 500, 'expected_nfft_min': 1024}, # 2^10 - {'pnts': 1000, 'expected_nfft_min': 2048}, # 2^11 - ] - - for case in test_cases: - with self.subTest(pnts=case['pnts']): - EEG = self.create_test_eeg(ncomp=2, pnts=case['pnts'], srate=256) - - # The function should work regardless of FFT size - result = eeg_autocorr(EEG) - - self.assertEqual(result.shape, (2, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_normalization_by_zero_tap(self): - """Test that autocorrelation is properly normalized by zero-tap.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - result = eeg_autocorr(EEG) - - # After normalization and resampling, we can't directly check for 1.0 at zero-tap - # since the first sample is removed, but values should be reasonable - self.assertTrue(np.all(np.abs(result) <= 10)) # Reasonable range after normalization - - def test_resampling_consistency(self): - """Test that resampling to 100 Hz is consistent.""" - # Test with different original sampling rates - srates = [128, 256, 512, 1000] - - for srate in srates: - with self.subTest(srate=srate): - EEG = self.create_test_eeg(ncomp=2, pnts=512, srate=srate) - result = eeg_autocorr(EEG) - - # All should resample to 100 samples (100 Hz - 1) - self.assertEqual(result.shape[1], 100) - - def test_multiple_trials(self): - """Test with multiple trials (though current implementation doesn't use 3rd dimension).""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, trials=5, srate=128) - - result = eeg_autocorr(EEG) - - # Should still work with multiple trials - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_parity_basic_autocorr(self): - """Test parity with MATLAB for basic autocorrelation using real ICA data.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Load real EEG dataset with ICA - from eegprep.functions.popfunc.pop_loadset import pop_loadset - import os - - test_file = os.path.join(os.path.dirname(__file__), '..', 'data', 'eeglab_data_with_ica_tmp.set') - if not os.path.exists(test_file): - self.skipTest(f"Test file not found: {test_file}") - - EEG = pop_loadset(test_file) - - # Python result - py_result = eeg_autocorr(EEG.copy()) - - # MATLAB result - ml_result = self.eeglab.eeg_autocorr(EEG.copy()) - - # Compare results - self.assertEqual(py_result.shape, ml_result.shape) - np.testing.assert_allclose(py_result, ml_result, rtol=1e-5, atol=1e-8) - - def test_parity_with_real_data(self): - """Test parity with MATLAB using real ICA data with different pct_data values.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Load real EEG dataset with ICA - from eegprep.functions.popfunc.pop_loadset import pop_loadset - import os - - test_file = os.path.join(os.path.dirname(__file__), '..', 'data', 'eeglab_data_with_ica_tmp.set') - if not os.path.exists(test_file): - self.skipTest(f"Test file not found: {test_file}") - - EEG = pop_loadset(test_file) - - # Test with different pct_data values - for pct_data in [50, 100]: - with self.subTest(pct_data=pct_data): - py_result = eeg_autocorr(EEG.copy(), pct_data=pct_data) - ml_result = self.eeglab.eeg_autocorr(EEG.copy(), pct_data) - - self.assertEqual(py_result.shape, ml_result.shape) - np.testing.assert_allclose(py_result, ml_result, rtol=1e-5, atol=1e-8) - - def test_edge_case_very_short_data(self): - """Test edge case with very short data.""" - EEG = self.create_test_eeg(ncomp=2, pnts=10, srate=256) - - result = eeg_autocorr(EEG) - - # Should still produce output - self.assertEqual(result.shape, (2, 100)) - # May contain NaN or inf due to very short data, but should not crash - - def test_edge_case_high_srate(self): - """Test edge case with very high sampling rate.""" - EEG = self.create_test_eeg(ncomp=2, pnts=1000, srate=2000) - - result = eeg_autocorr(EEG) - - # Should still produce 100 samples output - self.assertEqual(result.shape, (2, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_input_modification(self): - """Test that function modifies input EEG structure (icaact dtype).""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - original_dtype = EEG['icaact'].dtype - - eeg_autocorr(EEG) - - # Function should modify icaact dtype to float32 - self.assertEqual(EEG['icaact'].dtype, np.float32) - if original_dtype != np.float32: - self.assertNotEqual(EEG['icaact'].dtype, original_dtype) - - def test_deterministic_output(self): - """Test that function produces deterministic output for same input.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - # Make copies to avoid modification effects - EEG1 = {key: value.copy() if isinstance(value, np.ndarray) else value for key, value in EEG.items()} - EEG2 = {key: value.copy() if isinstance(value, np.ndarray) else value for key, value in EEG.items()} - - result1 = eeg_autocorr(EEG1) - result2 = eeg_autocorr(EEG2) - - # Octave loading in setUp affects numerical precision - # Observed: max_abs ~1e-08, max_rel ~1e-05 - np.testing.assert_allclose(result1, result2, rtol=2e-5, atol=2e-8) - - def test_memory_efficiency(self): - """Test that function works with larger datasets.""" - # Test with larger dataset to check memory efficiency - EEG = self.create_test_eeg(ncomp=100, pnts=2000, srate=1000) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") # Ignore potential memory warnings - result = eeg_autocorr(EEG) - - self.assertEqual(result.shape, (100, 100)) - # Don't check all finite for large datasets as it might be slow - self.assertTrue(result.dtype in [np.float32, np.float64]) # Either is acceptable - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_eeg_autocorr_fftw.py b/tests/test_eeg_autocorr_fftw.py deleted file mode 100644 index 8151dffb..00000000 --- a/tests/test_eeg_autocorr_fftw.py +++ /dev/null @@ -1,377 +0,0 @@ -""" -Test suite for eeg_autocorr_fftw.py with MATLAB parity validation. - -This module tests the eeg_autocorr_fftw function which computes autocorrelation -of ICA components using FFTW-optimized FFT operations. -""" - -# Disable multithreading for deterministic numerical results -import os - -os.environ["OMP_NUM_THREADS"] = "1" -os.environ["MKL_NUM_THREADS"] = "1" -os.environ["NUMEXPR_NUM_THREADS"] = "1" -os.environ["OPENBLAS_NUM_THREADS"] = "1" -os.environ["VECLIB_MAXIMUM_THREADS"] = "1" - -import unittest -import sys -import numpy as np -import warnings - -# Add src to path for imports -sys.path.insert(0, 'src') -from eegprep.plugins.ICLabel.eeg_autocorr_fftw import eeg_autocorr_fftw -from eegprep.functions.adminfunc.eeglabcompat import get_eeglab -from eegprep.utils.testing import DebuggableTestCase - - -class TestEegAutocorrFftw(DebuggableTestCase): - """Test cases for eeg_autocorr_fftw function.""" - - def setUp(self): - """Set up test fixtures.""" - # Set up MATLAB compatibility for parity tests - try: - self.eeglab = get_eeglab() - self.matlab_available = True - except Exception: - self.matlab_available = False - - def create_test_eeg(self, ncomp=10, pnts=1000, trials=1, srate=256): - """Create a test EEG structure with ICA data.""" - # Create realistic ICA activations - np.random.seed(42) # For reproducible tests - icaact = np.random.randn(ncomp, pnts, trials).astype( - np.float64 - ) # Use float64 to match MATLAB default precision - - return { - 'icaact': icaact, - 'pnts': pnts, - 'srate': srate, - 'trials': trials, - 'nbchan': 64, # Original channels before ICA - 'icaweights': np.random.randn(ncomp, 64).astype(np.float64), - 'icasphere': np.random.randn(64, 64).astype(np.float64), - } - - def test_basic_autocorrelation_fftw(self): - """Test basic autocorrelation computation using FFTW.""" - EEG = self.create_test_eeg(ncomp=5, pnts=512, srate=256) - - result = eeg_autocorr_fftw(EEG) - - # Check output shape - should be 100 samples (101 - 1) - expected_samples = 100 # Resampled to 100 Hz, first sample removed - self.assertEqual(result.shape, (5, expected_samples)) - - # Check that all values are finite - self.assertTrue(np.all(np.isfinite(result))) - - # Check data type - self.assertTrue(result.dtype in [np.float32, np.float64]) - - def test_default_pct_data(self): - """Test default pct_data parameter.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - # Test with default pct_data (should be 100) - result1 = eeg_autocorr_fftw(EEG) - result2 = eeg_autocorr_fftw(EEG, pct_data=100) - - # Octave loading in setUp can affect numerical precision - # Use tolerance to account for minor differences - np.testing.assert_allclose(result1, result2, rtol=2e-5, atol=2e-8) - - def test_explicit_pct_data(self): - """Test explicit pct_data parameter.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - # Test with explicit pct_data - result = eeg_autocorr_fftw(EEG, pct_data=50) - - # Should still produce same shape output - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_different_sample_rates(self): - """Test with different sampling rates.""" - test_cases = [ - {'srate': 128, 'expected_samples': 100}, - {'srate': 256, 'expected_samples': 100}, - {'srate': 500, 'expected_samples': 100}, - {'srate': 1000, 'expected_samples': 100}, - ] - - for case in test_cases: - with self.subTest(srate=case['srate']): - EEG = self.create_test_eeg(ncomp=2, pnts=512, srate=case['srate']) - result = eeg_autocorr_fftw(EEG) - - self.assertEqual(result.shape[0], 2) - self.assertEqual(result.shape[1], case['expected_samples']) - self.assertTrue(np.all(np.isfinite(result))) - - def test_short_data_padding(self): - """Test case where pnts < srate (requires padding).""" - EEG = self.create_test_eeg(ncomp=3, pnts=100, srate=256) # pnts < srate - - result = eeg_autocorr_fftw(EEG) - - # Should still produce expected output shape - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_long_data_truncation(self): - """Test case where pnts > srate (requires truncation).""" - EEG = self.create_test_eeg(ncomp=3, pnts=1000, srate=256) # pnts > srate - - result = eeg_autocorr_fftw(EEG) - - # Should still produce expected output shape - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_single_component(self): - """Test with single ICA component.""" - EEG = self.create_test_eeg(ncomp=1, pnts=512, srate=256) - - result = eeg_autocorr_fftw(EEG) - - self.assertEqual(result.shape, (1, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_many_components(self): - """Test with many ICA components.""" - EEG = self.create_test_eeg(ncomp=50, pnts=512, srate=256) - - result = eeg_autocorr_fftw(EEG) - - self.assertEqual(result.shape, (50, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_multiple_trials(self): - """Test with multiple trials (should use 3rd dimension).""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, trials=5, srate=128) - - result = eeg_autocorr_fftw(EEG) - - # Should work with multiple trials - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_fft_length_calculation(self): - """Test that FFT length is calculated correctly using next_fast_len.""" - from scipy.fft import next_fast_len - - # Test with different data lengths - test_cases = [ - {'pnts': 100}, - {'pnts': 256}, - {'pnts': 500}, - {'pnts': 1000}, - ] - - for case in test_cases: - with self.subTest(pnts=case['pnts']): - EEG = self.create_test_eeg(ncomp=2, pnts=case['pnts'], srate=256) - - # Calculate expected nfft - expected_nfft = next_fast_len(2 * case['pnts'] - 1) - - # The function should work regardless of FFT size - result = eeg_autocorr_fftw(EEG) - - self.assertEqual(result.shape, (2, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - # FFT length should be optimized - self.assertGreaterEqual(expected_nfft, 2 * case['pnts'] - 1) - - def test_power_spectrum_calculation(self): - """Test that power spectrum is calculated correctly.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - result = eeg_autocorr_fftw(EEG) - - # Power spectrum calculation should produce valid autocorrelation - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_normalization_by_zero_lag(self): - """Test that autocorrelation is properly normalized by zero-lag.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - result = eeg_autocorr_fftw(EEG) - - # After normalization and resampling, values should be reasonable - self.assertTrue(np.all(np.abs(result) <= 10)) # Reasonable range after normalization - - def test_real_output(self): - """Test that output is real-valued (no complex components).""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - result = eeg_autocorr_fftw(EEG) - - # Output should be real - self.assertTrue(np.all(np.isreal(result))) - self.assertTrue(result.dtype in [np.float32, np.float64]) - - def test_resampling_consistency(self): - """Test that resampling to 100 Hz is consistent.""" - # Test with different original sampling rates - srates = [128, 256, 512, 1000] - - for srate in srates: - with self.subTest(srate=srate): - EEG = self.create_test_eeg(ncomp=2, pnts=512, srate=srate) - result = eeg_autocorr_fftw(EEG) - - # All should resample to 100 samples (101 - 1) - self.assertEqual(result.shape[1], 100) - - def test_zero_component(self): - """Test with zero-valued component (edge case).""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - # Make one component all zeros - EEG['icaact'][1, :, :] = 0 - - result = eeg_autocorr_fftw(EEG) - - self.assertEqual(result.shape, (3, 100)) - # Zero component should produce NaN or inf values after normalization - # but the function should handle this gracefully - self.assertTrue(np.all(np.isfinite(result[0, :]))) # First component should be fine - self.assertTrue(np.all(np.isfinite(result[2, :]))) # Third component should be fine - # Second component (zero) might have NaN or inf, which is expected - - def test_edge_case_very_short_data(self): - """Test edge case with very short data.""" - EEG = self.create_test_eeg(ncomp=2, pnts=10, srate=256) - - result = eeg_autocorr_fftw(EEG) - - # Should still produce output - self.assertEqual(result.shape, (2, 100)) - # May contain NaN or inf due to very short data, but should not crash - - def test_edge_case_high_srate(self): - """Test edge case with very high sampling rate.""" - EEG = self.create_test_eeg(ncomp=2, pnts=1000, srate=2000) - - result = eeg_autocorr_fftw(EEG) - - # Should still produce 100 samples output - self.assertEqual(result.shape, (2, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_deterministic_output(self): - """Test that function produces deterministic output for same input.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - # Make copies to avoid modification effects - EEG1 = {key: value.copy() if isinstance(value, np.ndarray) else value for key, value in EEG.items()} - EEG2 = {key: value.copy() if isinstance(value, np.ndarray) else value for key, value in EEG.items()} - - result1 = eeg_autocorr_fftw(EEG1) - result2 = eeg_autocorr_fftw(EEG2) - - # Octave loading in setUp can affect numerical precision - # Use tolerance to account for minor differences - np.testing.assert_allclose(result1, result2, rtol=2e-5, atol=2e-8) - - def test_memory_efficiency(self): - """Test that function works with larger datasets.""" - # Test with larger dataset to check memory efficiency - EEG = self.create_test_eeg(ncomp=100, pnts=2000, srate=1000) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") # Ignore potential memory warnings - result = eeg_autocorr_fftw(EEG) - - self.assertEqual(result.shape, (100, 100)) - # Don't check all finite for large datasets as it might be slow - self.assertTrue(result.dtype in [np.float32, np.float64]) - - def test_comparison_with_regular_autocorr(self): - """Test that FFTW version produces similar results to regular version.""" - # Import the regular autocorr function for comparison - from eegprep.plugins.ICLabel.eeg_autocorr import eeg_autocorr - - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=128) - - # Make copies to avoid modification effects - EEG_fftw = {key: value.copy() if isinstance(value, np.ndarray) else value for key, value in EEG.items()} - EEG_regular = {key: value.copy() if isinstance(value, np.ndarray) else value for key, value in EEG.items()} - - result_fftw = eeg_autocorr_fftw(EEG_fftw) - result_regular = eeg_autocorr(EEG_regular) - - # Both should have same shape - self.assertEqual(result_fftw.shape, result_regular.shape) - - # Both compute the same autocorrelation; the only divergence is that - # eeg_autocorr casts its FFT to single precision (complex64) for MATLAB - # parity while eeg_autocorr_fftw stays double precision. Observed max - # relative difference is ~3e-5, so a float-realistic tolerance still - # catches any several-percent port regression. - self.assertTrue(np.allclose(result_fftw, result_regular, rtol=1e-4, atol=1e-7)) - - def test_axis_handling_in_fft(self): - """Test that FFT operations handle axes correctly.""" - EEG = self.create_test_eeg(ncomp=3, pnts=256, trials=5, srate=128) - - result = eeg_autocorr_fftw(EEG) - - # Should handle multiple trials correctly - self.assertEqual(result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_slicing_behavior(self): - """Test the final slicing behavior ac[:, 1:101].""" - EEG = self.create_test_eeg(ncomp=2, pnts=256, srate=128) - - result = eeg_autocorr_fftw(EEG) - - # Should slice correctly to get exactly 100 samples - self.assertEqual(result.shape[1], 100) - - def test_parity_basic_autocorr_fftw(self): - """Test parity with MATLAB for basic FFTW autocorrelation.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Create test data - EEG = self.create_test_eeg(ncomp=5, pnts=512, srate=256) - - # Python result - py_result = eeg_autocorr_fftw(EEG) - - # MATLAB result (would need to save EEG structure and call MATLAB) - # This is a placeholder for the parity test structure - # ml_result = self.eeglab.eeg_autocorr_fftw(EEG) - - # For now, just verify Python result is reasonable - self.assertEqual(py_result.shape, (5, 100)) - self.assertTrue(np.all(np.isfinite(py_result))) - - def test_parity_different_srates_fftw(self): - """Test parity with MATLAB for different sampling rates.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Test with different sampling rates - for srate in [128, 256, 512]: - with self.subTest(srate=srate): - EEG = self.create_test_eeg(ncomp=3, pnts=256, srate=srate) - - py_result = eeg_autocorr_fftw(EEG) - - # Placeholder for MATLAB comparison - self.assertEqual(py_result.shape, (3, 100)) - self.assertTrue(np.all(np.isfinite(py_result))) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_eeg_autocorr_welch.py b/tests/test_eeg_autocorr_welch.py deleted file mode 100644 index 71c0d209..00000000 --- a/tests/test_eeg_autocorr_welch.py +++ /dev/null @@ -1,396 +0,0 @@ -import unittest -import numpy as np -import os - -from eegprep.plugins.ICLabel.eeg_autocorr_welch import eeg_autocorr_welch - - -class TestEegAutocorrWelch(unittest.TestCase): - """Test the eeg_autocorr_welch function. - - Note: This function has some limitations/bugs with multi-trial data and - percentage sampling. Tests focus on functionality that actually works. - """ - - @classmethod - def setUpClass(cls): - """Set up MATLAB/Octave engine for parity testing.""" - cls.matlab_available = False - cls.eeglab = None - try: - from eegprep.functions.adminfunc.eeglabcompat import get_eeglab - - cls.eeglab = get_eeglab() - cls.matlab_available = True - except Exception as e: - print(f"MATLAB not available for parity testing: {e}") - - def setUp(self): - """Set up test fixtures with synthetic EEG data.""" - np.random.seed(42) # For reproducible tests - - # Create synthetic EEG data structure - self.n_components = 5 - self.n_channels = 8 - # Use data that creates exactly one segment to avoid the reshape issue - self.n_points = 750 # Exactly 3 seconds at 250 Hz - creates single segment - self.n_trials = 1 # Single trial to avoid the axis issue in the function - self.srate = 250.0 - - # Create synthetic ICA activations - self.icaact = np.random.randn(self.n_components, self.n_points, self.n_trials) * 0.5 - - # Add some realistic structure (sine waves with noise) - for comp in range(self.n_components): - freq = 5 + comp * 3 # Different frequencies per component - t = np.linspace(0, self.n_points / self.srate, self.n_points) - for trial in range(self.n_trials): - phase = np.random.rand() * 2 * np.pi - self.icaact[comp, :, trial] += 0.3 * np.sin(2 * np.pi * freq * t + phase) - - # Create ICA weights (not used in function but needed for validation) - self.icaweights = np.random.randn(self.n_components, self.n_channels) * 0.1 - - self.EEG = { - 'icaact': self.icaact, - 'icaweights': self.icaweights, - 'pnts': self.n_points, - 'trials': self.n_trials, - 'srate': self.srate, - } - - def test_basic_functionality(self): - """Test basic autocorrelation computation.""" - result = eeg_autocorr_welch(self.EEG, pct_data=100) - - # Check output shape: should be (n_components, 100) - resampled to 100 samples/sec for 1 second - self.assertEqual(result.shape, (self.n_components, 100)) - - # Check that result is finite - self.assertTrue(np.all(np.isfinite(result))) - - # Check that autocorrelation is normalized (first lag should be 1.0 after normalization) - # Note: the function normalizes and then resamples, so we check reasonable bounds - self.assertTrue(np.all(result >= -2)) # Allow some numerical tolerance - self.assertTrue(np.all(result <= 2)) - - def test_varying_pct_data(self): - """Test with different pct_data values.""" - # Test only pct_data=100 due to bugs in the function with percentage sampling - result_100 = eeg_autocorr_welch(self.EEG, pct_data=100) - - # Shape should be correct - self.assertEqual(result_100.shape, (self.n_components, 100)) - self.assertTrue(np.all(np.isfinite(result_100))) - - def test_pct_data_edge_cases(self): - """Test edge cases for pct_data parameter.""" - # Test None (should default to 100) - result_none = eeg_autocorr_welch(self.EEG, pct_data=None) - result_100 = eeg_autocorr_welch(self.EEG, pct_data=100) - np.testing.assert_allclose(result_none, result_100, atol=1e-10) - - # Test 0 (should default to 100) - result_zero = eeg_autocorr_welch(self.EEG, pct_data=0) - # check if they are equal to within 1e-10 - np.testing.assert_allclose(result_zero, result_100, atol=1e-10) - - def test_small_vs_large_pnts(self): - """Test with small vs large number of points.""" - # Test with small pnts (less than srate) - small_EEG = self.EEG.copy() - small_EEG['pnts'] = 100 # Less than srate (250) - small_EEG['icaact'] = self.icaact[:, :100, :] # Truncate data - - result_small = eeg_autocorr_welch(small_EEG, pct_data=100) - self.assertEqual(result_small.shape, (self.n_components, 100)) - self.assertTrue(np.all(np.isfinite(result_small))) - - # Test with large pnts (greater than srate) - large_EEG = self.EEG.copy() - large_EEG['pnts'] = 1000 # Greater than srate (250) - large_icaact = np.random.randn(self.n_components, 1000, self.n_trials) * 0.5 - large_EEG['icaact'] = large_icaact - - result_large = eeg_autocorr_welch(large_EEG, pct_data=100) - self.assertEqual(result_large.shape, (self.n_components, 100)) - self.assertTrue(np.all(np.isfinite(result_large))) - - # Results should be different due to different data lengths - self.assertFalse(np.allclose(result_small, result_large, atol=1e-10)) - - def test_normalization_at_lag0(self): - """Test that normalization properly handles lag 0.""" - # Create EEG with known autocorrelation structure - test_EEG = self.EEG.copy() - - # Create a simple signal with known autocorrelation - use single trial - n_comp, n_pnts, n_trials = 2, 750, 1 # Use single trial to avoid axis issues - test_EEG['icaweights'] = np.random.randn(n_comp, self.n_channels) - test_EEG['pnts'] = n_pnts - test_EEG['trials'] = n_trials - - # Create periodic signals - t = np.linspace(0, n_pnts / self.srate, n_pnts) - icaact = np.zeros((n_comp, n_pnts, n_trials)) - - for comp in range(n_comp): - for trial in range(n_trials): - # Simple sine wave - should have strong autocorrelation - freq = 10 # 10 Hz - icaact[comp, :, trial] = np.sin(2 * np.pi * freq * t) - - test_EEG['icaact'] = icaact - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - - # Check that autocorrelation has reasonable structure - # For periodic signals, autocorrelation should show periodic structure - self.assertEqual(result.shape, (n_comp, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - # The autocorrelation should have some structure (not all zeros) - self.assertTrue(np.any(np.abs(result) > 0.1)) - - def test_dtype_and_shape_consistency(self): - """Test data type and shape consistency.""" - # Test with different data types - for dtype in [np.float32, np.float64]: - with self.subTest(dtype=dtype): - test_EEG = self.EEG.copy() - test_EEG['icaact'] = test_EEG['icaact'].astype(dtype) - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - - # Output should be float (numpy default) - self.assertTrue(np.issubdtype(result.dtype, np.floating)) - self.assertEqual(result.shape, (self.n_components, 100)) - - def test_different_sampling_rates(self): - """Test with different sampling rates.""" - test_srates = [100, 128, 200, 256, 500] - - for srate in test_srates: - with self.subTest(srate=srate): - test_EEG = self.EEG.copy() - test_EEG['srate'] = srate - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - - # Output should always be resampled to 100 samples/sec for 1 second - self.assertEqual(result.shape, (self.n_components, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_different_component_counts(self): - """Test with different numbers of components.""" - for n_comp in [1, 3, 10, 20]: - with self.subTest(n_components=n_comp): - test_EEG = self.EEG.copy() - test_EEG['icaweights'] = np.random.randn(n_comp, self.n_channels) - test_EEG['icaact'] = np.random.randn(n_comp, self.n_points, self.n_trials) * 0.5 - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - - self.assertEqual(result.shape, (n_comp, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_different_trial_counts(self): - """Test with different numbers of trials.""" - # Only test single trial to avoid the axis issue in the function - for n_trials in [1]: - with self.subTest(n_trials=n_trials): - test_EEG = self.EEG.copy() - test_EEG['trials'] = n_trials - test_EEG['icaact'] = np.random.randn(self.n_components, self.n_points, n_trials) * 0.5 - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - - self.assertEqual(result.shape, (self.n_components, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_random_seed_determinism(self): - """Test that random seed produces deterministic results.""" - # Test that multiple calls produce same results (due to fixed seed in function) - result1 = eeg_autocorr_welch(self.EEG, pct_data=100) - result2 = eeg_autocorr_welch(self.EEG, pct_data=100) - - # Results should be identical for same data - np.testing.assert_allclose(result1, result2, atol=1e-10) - - def test_n_points_calculation(self): - """Test n_points calculation logic.""" - # Test case where pnts < srate * 3 - test_EEG = self.EEG.copy() - test_EEG['pnts'] = 200 - test_EEG['srate'] = 100 # srate * 3 = 300, so pnts < srate * 3 - test_EEG['icaact'] = np.random.randn(self.n_components, 200, self.n_trials) - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - self.assertEqual(result.shape, (self.n_components, 100)) - - # Test case where pnts >= srate * 3 - test_EEG['pnts'] = 400 - test_EEG['srate'] = 100 # srate * 3 = 300, so pnts > srate * 3 - test_EEG['icaact'] = np.random.randn(self.n_components, 400, self.n_trials) - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - self.assertEqual(result.shape, (self.n_components, 100)) - - def test_fft_size_calculation(self): - """Test FFT size calculation (power of 2).""" - # The function calculates nfft = 2**(int(np.log2(n_points * 2 - 1)) + 1) - # This should always result in a power of 2 >= 2 * n_points - 1 - - test_EEG = self.EEG.copy() - - # Test with various n_points values - for pnts in [64, 100, 256, 500, 1000]: - with self.subTest(pnts=pnts): - test_EEG['pnts'] = pnts - test_EEG['icaact'] = np.random.randn(self.n_components, pnts, self.n_trials) - - # This should not raise any errors - result = eeg_autocorr_welch(test_EEG, pct_data=100) - self.assertEqual(result.shape, (self.n_components, 100)) - - def test_segment_indexing(self): - """Test the segment indexing logic.""" - # Test with data that allows multiple segments - test_EEG = self.EEG.copy() - test_EEG['pnts'] = 1000 # Large enough for multiple segments - test_EEG['icaact'] = np.random.randn(self.n_components, 1000, self.n_trials) - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - self.assertEqual(result.shape, (self.n_components, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_autocorrelation_properties(self): - """Test mathematical properties of autocorrelation.""" - # Create a known signal - test_EEG = self.EEG.copy() - n_comp, n_pnts, n_trials = 1, 256, 1 - test_EEG['icaweights'] = np.random.randn(n_comp, self.n_channels) - test_EEG['pnts'] = n_pnts - test_EEG['trials'] = n_trials - test_EEG['srate'] = 128 # Nice power of 2 - - # Create white noise - icaact = np.random.randn(n_comp, n_pnts, n_trials) - test_EEG['icaact'] = icaact - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - - # For white noise, autocorrelation should decay quickly - self.assertEqual(result.shape, (n_comp, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - # Check that result has reasonable magnitude - self.assertTrue(np.all(np.abs(result) < 10)) # Should be reasonably bounded - - def test_edge_case_single_point(self): - """Test edge case with very small data.""" - test_EEG = self.EEG.copy() - test_EEG['pnts'] = 10 # Very small - test_EEG['icaact'] = np.random.randn(self.n_components, 10, self.n_trials) - - # Should still work without errors - result = eeg_autocorr_welch(test_EEG, pct_data=100) - self.assertEqual(result.shape, (self.n_components, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_resampling_consistency(self): - """Test that resampling produces consistent results.""" - # Test with sampling rates that are multiples of 100 - for srate in [100, 200, 400]: - with self.subTest(srate=srate): - test_EEG = self.EEG.copy() - test_EEG['srate'] = srate - - result = eeg_autocorr_welch(test_EEG, pct_data=100) - - # Should always resample to 100 samples - self.assertEqual(result.shape, (self.n_components, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_memory_efficiency(self): - """Test with larger datasets to check memory efficiency.""" - # Create a larger dataset but with single trial to avoid function bugs - large_EEG = { - 'icaweights': np.random.randn(10, 32), # 10 components, 32 channels - 'icaact': np.random.randn(10, 750, 1) * 0.5, # Large data, single trial - 'pnts': 750, - 'trials': 1, - 'srate': 250, - } - - # This should complete without memory errors - result = eeg_autocorr_welch(large_EEG, pct_data=100) - self.assertEqual(result.shape, (10, 100)) - self.assertTrue(np.all(np.isfinite(result))) - - def test_numerical_stability(self): - """Test numerical stability with extreme values.""" - test_EEG = self.EEG.copy() - - # Test with very small values - test_EEG['icaact'] = np.random.randn(*test_EEG['icaact'].shape) * 1e-10 - result_small = eeg_autocorr_welch(test_EEG, pct_data=100) - self.assertTrue(np.all(np.isfinite(result_small))) - - # Test with larger values - test_EEG['icaact'] = np.random.randn(*test_EEG['icaact'].shape) * 100 - result_large = eeg_autocorr_welch(test_EEG, pct_data=100) - self.assertTrue(np.all(np.isfinite(result_large))) - - def test_parity_basic_autocorr_welch(self): - """Test parity with MATLAB for basic autocorrelation using real ICA data.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Load real EEG dataset with ICA - from eegprep.functions.popfunc.pop_loadset import pop_loadset - - test_file = os.path.join(os.path.dirname(__file__), '..', 'data', 'eeglab_data_with_ica_tmp.set') - if not os.path.exists(test_file): - self.skipTest(f"Test file not found: {test_file}") - - EEG = pop_loadset(test_file) - - # Python result - py_result = eeg_autocorr_welch(EEG.copy()) - - # MATLAB result - ml_result = self.eeglab.eeg_autocorr_welch(EEG.copy()) - - # Compare results - self.assertEqual(py_result.shape, ml_result.shape) - np.testing.assert_allclose(py_result, ml_result, rtol=1e-5, atol=1e-8) - - def test_parity_with_real_data_welch(self): - """Test parity with MATLAB using real ICA data. - - Note: Only tests pct_data=100 due to known bug in Python implementation - with pct_data < 100 (index out of bounds in segment selection). - """ - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Load real EEG dataset with ICA - from eegprep.functions.popfunc.pop_loadset import pop_loadset - - test_file = os.path.join(os.path.dirname(__file__), '..', 'data', 'eeglab_data_with_ica_tmp.set') - if not os.path.exists(test_file): - self.skipTest(f"Test file not found: {test_file}") - - EEG = pop_loadset(test_file) - - # Only test with pct_data=100 (pct_data < 100 has a bug in Python implementation) - py_result = eeg_autocorr_welch(EEG.copy(), pct_data=100) - ml_result = self.eeglab.eeg_autocorr_welch(EEG.copy(), 100) - - self.assertEqual(py_result.shape, ml_result.shape) - np.testing.assert_allclose(py_result, ml_result, rtol=1e-5, atol=1e-8) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_eeg_rpsd_parity.py b/tests/test_eeg_rpsd_parity.py deleted file mode 100644 index 16cfa53b..00000000 --- a/tests/test_eeg_rpsd_parity.py +++ /dev/null @@ -1,232 +0,0 @@ -""" -Test parity between Python and MATLAB implementations of eeg_rpsd. - -This test compares the Python implementation against the MATLAB/EEGLAB reference. -Multithreading is disabled for deterministic numerical results. -""" - -# Disable multithreading for deterministic numerical results -import os - -os.environ["OMP_NUM_THREADS"] = "1" -os.environ["MKL_NUM_THREADS"] = "1" -os.environ["NUMEXPR_NUM_THREADS"] = "1" -os.environ["OPENBLAS_NUM_THREADS"] = "1" -os.environ["VECLIB_MAXIMUM_THREADS"] = "1" - -import unittest -import numpy as np -import tempfile -import scipy.io -from eegprep import pop_loadset, pop_saveset, eeg_rpsd -from eegprep.functions.adminfunc.eeglabcompat import get_eeglab - -local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') - - -class TestEegRpsdParity(unittest.TestCase): - """Test parity between Python and MATLAB eeg_rpsd implementations.""" - - def setUp(self): - """Set up test fixtures.""" - # Try to get MATLAB engine - try: - self.eeglab = get_eeglab('MAT', auto_file_roundtrip=False) - self.matlab_available = True - except Exception as e: - self.matlab_available = False - self.skipTest(f"MATLAB not available: {e}") - - # Load real EEG dataset with ICA - test_file = os.path.join(local_url, 'eeglab_data_with_ica_tmp.set') - self.EEG = pop_loadset(test_file) - - def test_parity_default_nfreqs(self): - """Test parity with MATLAB using default nfreqs parameter.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result - py_result = eeg_rpsd(self.EEG.copy()) - - # MATLAB result - use file roundtrip - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - psdmed = eeg_rpsd(EEG); - save('{temp_file}.mat', 'psdmed'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB result - mat_data = scipy.io.loadmat(temp_file + '.mat') - ml_result = mat_data['psdmed'] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare results - # Max absolute diff: 0.00017, Mismatched: 1/2048 (0.05%) - # Max relative diff: 1.14e-05 - np.testing.assert_allclose( - py_result, - ml_result, - rtol=2e-5, - atol=1e-8, - err_msg="eeg_rpsd results differ beyond tolerance (default nfreqs)", - ) - - def test_parity_custom_nfreqs_100(self): - """Test parity with MATLAB using nfreqs=100.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result - py_result = eeg_rpsd(self.EEG.copy(), nfreqs=100) - - # MATLAB result - use file roundtrip - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - psdmed = eeg_rpsd(EEG, 100); - save('{temp_file}.mat', 'psdmed'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB result - mat_data = scipy.io.loadmat(temp_file + '.mat') - ml_result = mat_data['psdmed'] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare results - # Max absolute diff: 0.00017, Mismatched: 1/2048 (0.05%) - # Max relative diff: 1.14e-05 - np.testing.assert_allclose( - py_result, ml_result, rtol=2e-5, atol=1e-8, err_msg="eeg_rpsd results differ beyond tolerance (nfreqs=100)" - ) - - def test_parity_custom_nfreqs_50(self): - """Test parity with MATLAB using nfreqs=50.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result - py_result = eeg_rpsd(self.EEG.copy(), nfreqs=50) - - # MATLAB result - use file roundtrip - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - psdmed = eeg_rpsd(EEG, 50); - save('{temp_file}.mat', 'psdmed'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB result - mat_data = scipy.io.loadmat(temp_file + '.mat') - ml_result = mat_data['psdmed'] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare results - # Max absolute diff: ~0.0002, Mismatched: ~1/1024 (0.1%) - # Max relative diff: ~1.14e-05 - np.testing.assert_allclose( - py_result, ml_result, rtol=2e-5, atol=1e-8, err_msg="eeg_rpsd results differ beyond tolerance (nfreqs=50)" - ) - - def test_parity_with_icl_processing(self): - """Test parity with MATLAB including ICL processing (notch undo + normalization).""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Python result with ICL processing - psd = eeg_rpsd(self.EEG.copy(), 100) - - # Extrapolate or prune as needed (from ICL_feature_extractor lines 50-53) - nfreq = psd.shape[1] - if nfreq < 100: - psd = np.hstack((psd, np.tile(psd[:, -1][:, np.newaxis], (1, 100 - nfreq)))) - - # Undo notch filter (from ICL_feature_extractor lines 55-61) - for linenoise_ind in [50, 60]: - linenoise_around = [linenoise_ind - 1, linenoise_ind + 1] - difference = psd[:, linenoise_around] - psd[:, linenoise_ind][:, np.newaxis] - notch_ind = np.all(difference > 5, axis=1) - if np.any(notch_ind): - psd[notch_ind, linenoise_ind] = np.mean(psd[notch_ind][:, linenoise_around], axis=1) - - # Normalize (from ICL_feature_extractor line 64) - py_result = psd / np.max(np.abs(psd), axis=1)[:, np.newaxis] - - # MATLAB result - use file roundtrip with same processing - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - - % Call eeg_rpsd - psd = eeg_rpsd(EEG, 100); - - % Extrapolate or prune as needed - nfreq = size(psd, 2); - if nfreq < 100 - psd = [psd, repmat(psd(:, end), 1, 100 - nfreq)]; - end - - % Undo notch filter - for linenoise_ind = [50, 60] - linenoise_around = [linenoise_ind - 1, linenoise_ind + 1]; - difference = psd(:, linenoise_around) - repmat(psd(:, linenoise_ind), 1, 2); - notch_ind = all(difference > 5, 2); - if any(notch_ind) - psd(notch_ind, linenoise_ind) = mean(psd(notch_ind, linenoise_around), 2); - end - end - - % Normalize - psd = psd ./ max(abs(psd), [], 2); - - save('{temp_file}.mat', 'psd'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB result - mat_data = scipy.io.loadmat(temp_file + '.mat') - ml_result = mat_data['psd'] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - # Compare results - # Max absolute diff: TBD - # Max relative diff: TBD - np.testing.assert_allclose( - py_result, ml_result, rtol=2e-5, atol=1e-8, err_msg="eeg_rpsd with ICL processing differs beyond tolerance" - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_eeglabcompat.py b/tests/test_eeglabcompat.py index 910cb7f9..f82ca96a 100644 --- a/tests/test_eeglabcompat.py +++ b/tests/test_eeglabcompat.py @@ -20,7 +20,7 @@ pop_eegfiltnew, eeg_checkset as eeglab_eeg_checkset, ) -from eegprep import clean_artifacts, pop_loadset +from eegprep import pop_loadset from eegprep.functions.adminfunc.eeg_checkset import eeg_checkset from eegprep.utils.testing import DebuggableTestCase import eegprep.functions.adminfunc.eeglabcompat as eeglabcompat diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 97060251..aa9228f7 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -54,8 +54,8 @@ def test_bundled_extension_records_match_plugin_inventory() -> None: records = registry.discover() - assert [record.name for record in records] == ["clean_rawdata", "ICLabel", "firfilt", "dipfit", "EEG_BIDS"] - assert [record.status for record in records] == [ExtensionStatus.BUNDLED] * 5 + assert [record.name for record in records] == ["firfilt", "dipfit", "EEG_BIDS"] + assert [record.status for record in records] == [ExtensionStatus.BUNDLED] * 3 assert all(record.spec is not None for record in records) diff --git a/tests/test_gui_main_window.py b/tests/test_gui_main_window.py index c7a529f6..e56de1cc 100644 --- a/tests/test_gui_main_window.py +++ b/tests/test_gui_main_window.py @@ -150,7 +150,6 @@ def test_default_menu_matches_eeglab_top_level_and_hides_legacy_items(self): self.assertIn('(Expand tool choices via "File > Preferences")', tools_labels) self.assertNotIn("Automatic channel rejection", tools_labels) self.assertIn("Reject data using Clean Rawdata and ASR", tools_labels) - self.assertIn("Classify components using ICLabel", tools_labels) self.assertIn("Source localization using DIPFIT", tools_labels) def test_all_menus_mode_reveals_legacy_items_and_hides_expand_prompt(self): @@ -313,9 +312,7 @@ def test_all_menu_actions_are_classified(self): self.assertEqual(action_kind("pop_firws"), "implemented") self.assertEqual(action_kind("pop_firpm"), "implemented") self.assertEqual(action_kind("pop_firma"), "implemented") - self.assertEqual(action_kind("pop_clean_rawdata"), "implemented") self.assertEqual(action_kind("pop_runica"), "implemented") - self.assertEqual(action_kind("pop_iclabel"), "implemented") self.assertEqual(action_kind("pop_icflag"), "implemented") self.assertEqual(action_kind("pop_subcomp"), "implemented") self.assertEqual(action_kind("pop_exportbids"), "implemented") @@ -952,7 +949,6 @@ def test_resave_multiple_datasets_does_not_collapse_selection(self): def test_new_main_window_pop_actions_dispatch_to_real_wrappers(self): newset_actions = { - "pop_clean_rawdata", "pop_eegfilt", "pop_eegfiltnew", "pop_epoch", @@ -982,9 +978,7 @@ def test_new_main_window_pop_actions_dispatch_to_real_wrappers(self): ("pop_firws", "eegprep.plugins.firfilt.pop_firws.pop_firws", "firws"), ("pop_firpm", "eegprep.plugins.firfilt.pop_firpm.pop_firpm", "firpm"), ("pop_firma", "eegprep.plugins.firfilt.pop_firma.pop_firma", "firma"), - ("pop_clean_rawdata", "eegprep.plugins.clean_rawdata.pop_clean_rawdata.pop_clean_rawdata", "cleaned"), ("pop_runica", "eegprep.functions.popfunc.pop_runica.pop_runica", "ica"), - ("pop_iclabel", "eegprep.plugins.ICLabel.pop_iclabel.pop_iclabel", "labeled"), ] for action, patch_target, setname in action_specs: diff --git a/tests/test_gui_pop_clean_rawdata.py b/tests/test_gui_pop_clean_rawdata.py deleted file mode 100644 index 89b829b4..00000000 --- a/tests/test_gui_pop_clean_rawdata.py +++ /dev/null @@ -1,239 +0,0 @@ -import unittest -from unittest import mock - -import numpy as np - -from eegprep.functions.guifunc.spec import controls_by_tag -from eegprep.plugins.clean_rawdata.pop_clean_rawdata import ( - pop_clean_rawdata, - pop_clean_rawdata_dialog_spec, -) -from eegprep.plugins.clean_rawdata.vis_artifacts import vis_artifacts, vis_artifacts_diagnostics - - -def _eeg(*, epoched=False): - return { - "data": np.zeros((2, 20, 2), dtype=np.float32) if epoched else np.zeros((2, 40), dtype=np.float32), - "nbchan": 2, - "pnts": 20 if epoched else 40, - "trials": 2 if epoched else 1, - "srate": 100, - "xmin": 0, - "xmax": 0.19 if epoched else 0.39, - "chanlocs": [{"labels": "Cz"}, {"labels": "Pz"}], - "etc": {}, - } - - -class PopCleanRawdataGuiTests(unittest.TestCase): - def test_gui_dialog_spec_matches_clean_rawdata_sections(self): - spec = pop_clean_rawdata_dialog_spec(_eeg()) - - self.assertEqual(spec.title, "pop_clean_rawdata()") - self.assertEqual(spec.function_name, "pop_clean_rawdata") - self.assertEqual(spec.eeglab_source, "plugins/clean_rawdata/pop_clean_rawdata.m") - self.assertEqual(spec.help_text, "pophelp('pop_clean_rawdata')") - labels = [(control.style, control.string, control.tag) for control in spec.controls] - self.assertIn(("checkbox", "Remove channel drift (data not already high-pass filtered)", "filter"), labels) - self.assertIn(("checkbox", "Process/remove channels", "chanrm"), labels) - self.assertIn( - ("checkbox", "Perform Artifact Subspace Reconstruction bad burst correction/rejection", "asr"), labels - ) - self.assertIn(("checkbox", "Additional removal of bad data periods", "rejwin"), labels) - controls = controls_by_tag(spec) - self.assertEqual(controls["filter"].font_weight, "bold") - self.assertEqual(controls["chanrm"].font_weight, "bold") - self.assertEqual(controls["asr"].font_weight, "bold") - self.assertEqual(controls["rejwin"].font_weight, "bold") - self.assertTrue(controls["vis"].value) - - def test_gui_channel_callbacks_expose_labels(self): - controls = controls_by_tag(pop_clean_rawdata_dialog_spec(_eeg())) - - self.assertEqual(controls["chanuse_button"].callback.params["channels"], ("Cz", "Pz")) - self.assertEqual(controls["chanignore_button"].callback.params["channels"], ("Cz", "Pz")) - self.assertEqual(controls["filter"].callback.name, "toggle_enabled") - self.assertEqual(controls["filter"].callback.params["targets"], ("filterfreqs",)) - - def test_gui_channel_callbacks_accept_numpy_chanlocs(self): - eeg = _eeg() - eeg["chanlocs"] = np.asarray(eeg["chanlocs"], dtype=object) - - controls = controls_by_tag(pop_clean_rawdata_dialog_spec(eeg)) - - self.assertEqual(controls["chanuse_button"].callback.params["channels"], ("Cz", "Pz")) - self.assertEqual(controls["chanignore_button"].callback.params["channels"], ("Cz", "Pz")) - - def test_gui_result_runs_clean_artifacts_and_returns_history(self): - class Renderer: - def run(self, spec, initial_values=None): - return { - "filter": True, - "filterfreqs": "0.25 0.75", - "chanrm": True, - "chanignoreflag": False, - "chanignore": "", - "chanuseflag": False, - "chanuse": "", - "rmflat": True, - "rmflatsec": "5", - "rmcorr": True, - "rmcorrval": "0.8", - "rmnoise": True, - "rmnoiseval": "4", - "asr": True, - "asrstdval": "20", - "distance": False, - "rejwin": True, - "rejwinval1": "-Inf 7", - "rejwinval2": "25", - "asrrej": True, - "vis": False, - } - - eeg = _eeg() - with mock.patch( - "eegprep.plugins.clean_rawdata.pop_clean_rawdata.clean_artifacts", - return_value=(dict(eeg, setname="cleaned"), eeg, eeg, np.zeros(2, dtype=bool)), - ) as clean: - out, com = pop_clean_rawdata(eeg, gui=True, renderer=Renderer(), return_com=True) - - clean.assert_called_once() - self.assertEqual(out["setname"], "cleaned") - self.assertIn("'BurstCriterion', 20", com) - self.assertIn("'BurstRejection', 'on'", com) - - def test_gui_vis_checkbox_opens_rejected_data_browser_when_checked(self): - class Renderer: - def run(self, spec, initial_values=None): - return { - "filter": False, - "filterfreqs": "", - "chanrm": False, - "chanignoreflag": False, - "chanignore": "", - "chanuseflag": False, - "chanuse": "", - "rmflat": False, - "rmflatsec": "5", - "rmcorr": False, - "rmcorrval": "0.8", - "rmnoise": False, - "rmnoiseval": "4", - "asr": False, - "asrstdval": "20", - "distance": False, - "rejwin": False, - "rejwinval1": "-Inf 7", - "rejwinval2": "25", - "asrrej": False, - "vis": True, - } - - eeg = _eeg() - with ( - mock.patch( - "eegprep.plugins.clean_rawdata.pop_clean_rawdata.clean_artifacts", - return_value=( - dict( - eeg, - setname="cleaned", - etc={"clean_sample_mask": np.r_[np.ones(10, dtype=bool), np.zeros(30, dtype=bool)]}, - ), - eeg, - eeg, - np.zeros(2, dtype=bool), - ), - ) as clean, - mock.patch("eegprep.plugins.clean_rawdata.pop_clean_rawdata.vis_artifacts") as artifacts, - ): - out, com = pop_clean_rawdata(eeg, gui=True, renderer=Renderer(), return_com=True) - - clean.assert_called_once() - artifacts.assert_called_once() - shown, original = artifacts.call_args.args - np.testing.assert_array_equal(shown["etc"]["clean_sample_mask"][10:], np.zeros(30, dtype=bool)) - self.assertEqual(original["pnts"], eeg["pnts"]) - np.testing.assert_array_equal(original["data"], eeg["data"]) - self.assertEqual(out["setname"], "cleaned") - self.assertNotIn("_show_vis_artifacts", com) - - def test_vis_artifacts_diagnostics_summarizes_samples_and_channels(self): - old = _eeg() - old["chanlocs"] = np.asarray(old["chanlocs"], dtype=object) - new = dict( - old, - data=old["data"][:, :30], - pnts=30, - etc={ - "clean_sample_mask": np.r_[np.ones(10, dtype=bool), np.zeros(5, dtype=bool), np.ones(25, dtype=bool)], - "clean_channel_mask": np.asarray([True, False]), - }, - ) - - diag = vis_artifacts_diagnostics(new, old) - - self.assertEqual(diag["original_samples"], 40) - self.assertEqual(diag["clean_samples"], 30) - self.assertEqual(diag["rejected_sample_count"], 5) - np.testing.assert_array_equal(diag["rejected_intervals"], [[11, 15]]) - self.assertEqual(diag["removed_channel_indices"], [2]) - self.assertEqual(diag["removed_channel_labels"], ["Pz"]) - self.assertEqual(diag["winrej"].shape, (1, 7)) - - def test_vis_artifacts_can_return_diagnostics_without_opening_browser(self): - old = _eeg() - new = dict( - old, - etc={"clean_sample_mask": np.r_[np.zeros(3, dtype=bool), np.ones(37, dtype=bool)]}, - ) - - diag = vis_artifacts(new, old, show=False) - - np.testing.assert_array_equal(diag["rejected_intervals"], [[1, 3]]) - self.assertEqual(diag["rejected_fraction"], 3 / 40) - - def test_vis_artifacts_diagnostics_infers_original_size_from_masks(self): - clean = _eeg() - clean["data"] = clean["data"][:1, :30] - clean["nbchan"] = 1 - clean["pnts"] = 30 - clean["chanlocs"] = [{"labels": "Cz"}] - clean["etc"] = { - "clean_sample_mask": np.r_[np.ones(10, dtype=bool), np.zeros(5, dtype=bool), np.ones(25, dtype=bool)], - "clean_channel_mask": np.asarray([True, False]), - } - - diag = vis_artifacts_diagnostics(clean) - - self.assertEqual(diag["original_samples"], 40) - self.assertEqual(diag["clean_samples"], 30) - self.assertEqual(diag["original_channels"], 2) - self.assertEqual(diag["clean_channels"], 1) - np.testing.assert_array_equal(diag["rejected_intervals"], [[11, 15]]) - self.assertEqual(diag["removed_channel_indices"], [2]) - - def test_string_channel_lists_use_matlab_cell_history(self): - eeg = _eeg() - with mock.patch( - "eegprep.plugins.clean_rawdata.pop_clean_rawdata.clean_artifacts", - return_value=(dict(eeg, setname="cleaned"), eeg, eeg, np.zeros(2, dtype=bool)), - ): - _out, com = pop_clean_rawdata( - eeg, - gui=False, - Channels=["Cz", "Pz"], - Channels_ignore=["ECG"], - return_com=True, - ) - - self.assertIn("'Channels', {'Cz' 'Pz'}", com) - self.assertIn("'Channels_ignore', {'ECG'}", com) - - def test_epoched_data_raises_clear_error(self): - with self.assertRaisesRegex(ValueError, "continuous"): - pop_clean_rawdata(_eeg(epoched=True), gui=False) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_gui_pop_icflag.py b/tests/test_gui_pop_icflag.py deleted file mode 100644 index 19c8bae0..00000000 --- a/tests/test_gui_pop_icflag.py +++ /dev/null @@ -1,111 +0,0 @@ -import unittest - -import numpy as np - -from eegprep.functions.adminfunc.console import _console_python_command -from eegprep.functions.guifunc.spec import controls_by_tag -from eegprep.plugins.ICLabel.eeg_icflag import eeg_icflag -from eegprep.plugins.ICLabel.pop_icflag import DEFAULT_ICFLAG_THRESHOLDS, pop_icflag, pop_icflag_dialog_spec - - -def _eeg(): - return { - "data": np.zeros((3, 20), dtype=np.float32), - "nbchan": 3, - "pnts": 20, - "trials": 1, - "srate": 100, - "icaweights": np.eye(3), - "icasphere": np.eye(3), - "icawinv": np.eye(3), - "icachansind": np.arange(3), - "reject": {"rejmanual": np.array([1, 0])}, - "etc": { - "ic_classification": { - "ICLabel": { - "classifications": np.array( - [ - [0.70, 0.10, 0.10, 0.03, 0.02, 0.03, 0.02], - [0.02, 0.94, 0.02, 0.01, 0.00, 0.00, 0.01], - [0.05, 0.02, 0.91, 0.01, 0.00, 0.00, 0.01], - ] - ) - } - } - }, - } - - -class PopIcflagGuiTests(unittest.TestCase): - def test_dialog_spec_matches_eeglab_threshold_prompt(self): - spec = pop_icflag_dialog_spec() - controls = controls_by_tag(spec) - - self.assertEqual(spec.title, "Flag components using ICLabel -- pop_icflag()") - self.assertEqual(spec.function_name, "pop_icflag") - self.assertEqual(spec.eeglab_source, "plugins/ICLabel/pop_icflag.m") - self.assertEqual(spec.controls[0].font_weight, "bold") - self.assertEqual(controls["min_1"].value, "0.9") - self.assertEqual(controls["max_1"].value, "1") - self.assertEqual(controls["min_2"].value, "0.9") - self.assertEqual(controls["max_2"].value, "1") - - def test_gui_result_flags_components_and_returns_replayable_history(self): - class Renderer: - def run(self, spec, initial_values=None): - return { - "min_0": "", - "max_0": "", - "min_1": "0.9", - "max_1": "1", - "min_2": "0.9", - "max_2": "1", - "min_3": "", - "max_3": "", - "min_4": "", - "max_4": "", - "min_5": "", - "max_5": "", - "min_6": "", - "max_6": "", - } - - out, com = pop_icflag(_eeg(), gui=True, renderer=Renderer(), return_com=True) - - np.testing.assert_array_equal(out["reject"]["gcompreject"], [0, 1, 1]) - np.testing.assert_array_equal(out["reject"]["rejmanual"], [1, 0]) - self.assertEqual( - _console_python_command(com), - ( - "EEG = pop_icflag(EEG, thresholds=[[None, None], [0.9, 1], " - "[0.9, 1], [None, None], [None, None], [None, None], [None, None]])" - ), - ) - - def test_eeg_icflag_uses_eeglab_open_interval_thresholds(self): - eeg = _eeg() - thresholds = np.array(DEFAULT_ICFLAG_THRESHOLDS) - eeg["etc"]["ic_classification"]["ICLabel"]["classifications"][1, 1] = 0.9 - - out = eeg_icflag(eeg, thresholds) - - np.testing.assert_array_equal(out["reject"]["gcompreject"], [0, 0, 1]) - - def test_missing_iclabel_raises_clear_error(self): - eeg = _eeg() - eeg["etc"] = {} - - with self.assertRaisesRegex(ValueError, "Run pop_iclabel first"): - pop_icflag(eeg, DEFAULT_ICFLAG_THRESHOLDS) - - def test_missing_iclabel_in_dataset_list_raises_clear_error(self): - eeg = _eeg() - missing = _eeg() - missing["etc"] = {} - - with self.assertRaisesRegex(ValueError, "Run pop_iclabel first"): - pop_icflag([eeg, missing], DEFAULT_ICFLAG_THRESHOLDS) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_gui_pop_iclabel.py b/tests/test_gui_pop_iclabel.py deleted file mode 100644 index 4b180fd7..00000000 --- a/tests/test_gui_pop_iclabel.py +++ /dev/null @@ -1,80 +0,0 @@ -import unittest -from unittest import mock - -import numpy as np - -from eegprep.plugins.ICLabel.pop_iclabel import pop_iclabel, pop_iclabel_dialog_spec - - -def _eeg(): - return { - "data": np.zeros((2, 20), dtype=np.float32), - "nbchan": 2, - "pnts": 20, - "trials": 1, - "srate": 100, - "icaweights": np.eye(2), - "icasphere": np.eye(2), - "icawinv": np.eye(2), - "icachansind": np.arange(2), - "etc": {}, - } - - -class PopIclabelGuiTests(unittest.TestCase): - def test_gui_dialog_spec_matches_iclabel_prompt(self): - spec = pop_iclabel_dialog_spec() - - self.assertEqual(spec.title, "ICLabel") - self.assertEqual(spec.function_name, "pop_iclabel") - self.assertEqual(spec.eeglab_source, "plugins/ICLabel/pop_iclabel.m") - self.assertIsNone(spec.help_text) - self.assertFalse(spec.show_help_button) - self.assertEqual( - [(control.style, control.string, control.tag) for control in spec.controls], - [ - ("text", "Select which icversion of ICLabel to use:", None), - ("popupmenu", "Default (recommended)|Lite|Beta", "icversion"), - ], - ) - - def test_gui_result_runs_iclabel_and_returns_history(self): - class Renderer: - def run(self, spec, initial_values=None): - return {"icversion": 1} - - eeg = _eeg() - updated = dict(eeg, etc={"ic_classification": {"ICLabel": {"version": "default"}}}) - with mock.patch("eegprep.plugins.ICLabel.pop_iclabel.iclabel", return_value=updated) as classify: - out, com = pop_iclabel(eeg, gui=True, renderer=Renderer(), return_com=True) - - classify.assert_called_once_with(eeg, algorithm="default", engine=None) - self.assertEqual(out["etc"]["ic_classification"]["ICLabel"]["version"], "default") - self.assertEqual(com, "EEG = pop_iclabel(EEG, 'default');") - - def test_python_engine_rejects_unbundled_lite_and_beta_networks(self): - eeg = _eeg() - - with self.assertRaisesRegex(NotImplementedError, "standalone Python ICLabel only ships the default network"): - pop_iclabel(eeg, "lite") - - def test_matlab_engine_can_request_lite_network(self): - eeg = _eeg() - updated = dict(eeg, etc={"ic_classification": {"ICLabel": {"version": "lite"}}}) - - with mock.patch("eegprep.plugins.ICLabel.pop_iclabel.iclabel", return_value=updated) as classify: - out, com = pop_iclabel(eeg, "lite", engine="matlab", return_com=True) - - classify.assert_called_once_with(eeg, algorithm="lite", engine="matlab") - self.assertEqual(out["etc"]["ic_classification"]["ICLabel"]["version"], "lite") - self.assertEqual(com, "EEG = pop_iclabel(EEG, 'lite');") - - def test_missing_ica_raises_clear_error(self): - eeg = dict(_eeg(), icaweights=np.array([])) - - with self.assertRaisesRegex(ValueError, "requires an ICA decomposition"): - pop_iclabel(eeg, "default") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_gui_pop_prop_extended.py b/tests/test_gui_pop_prop_extended.py deleted file mode 100644 index 7dcd15c3..00000000 --- a/tests/test_gui_pop_prop_extended.py +++ /dev/null @@ -1,277 +0,0 @@ -import matplotlib - -matplotlib.use("Agg") - -from matplotlib.widgets import Button -import matplotlib.pyplot as plt -import numpy as np -import pytest - -from eegprep.functions.guifunc.pophelp import pophelp_text -from eegprep.plugins.ICLabel.pop_prop_extended import pop_prop_extended, pop_prop_extended_dialog_spec -from eegprep.plugins.ICLabel.pop_viewprops import pop_viewprops -from tests.fixtures import create_test_eeg_with_ica - - -def _dashboard_eeg(*, include_dipfit: bool = False) -> dict: - eeg = create_test_eeg_with_ica(n_channels=4, n_samples=100, srate=100.0, n_components=4, n_trials=1) - samples = np.linspace(0.0, 1.0, 100) - eeg["data"] = np.vstack([np.sin(2 * np.pi * (index + 1) * samples) for index in range(4)]) - eeg["icaweights"] = np.eye(4) - eeg["icasphere"] = np.eye(4) - eeg["icawinv"] = np.eye(4) - eeg["icaact"] = np.vstack([np.cos(2 * np.pi * (index + 1) * samples) for index in range(4)]) - eeg["icachansind"] = np.arange(4) - eeg["times"] = samples * 1000.0 - eeg["xmin"] = 0.0 - eeg["xmax"] = 1.0 - eeg["event"] = [{"type": "stim", "latency": 25.0, "duration": 0.0}] - eeg["reject"] = {"gcompreject": np.zeros(4, dtype=int)} - eeg["etc"] = { - "ic_classification": { - "ICLabel": { - "classifications": np.array( - [ - [0.70, 0.10, 0.10, 0.03, 0.02, 0.03, 0.02], - [0.02, 0.94, 0.02, 0.01, 0.00, 0.00, 0.01], - [0.05, 0.02, 0.91, 0.01, 0.00, 0.00, 0.01], - [0.80, 0.05, 0.05, 0.02, 0.02, 0.03, 0.03], - ] - ), - "classes": ["Brain", "Muscle", "Eye", "Heart", "Line Noise", "Channel Noise", "Other"], - } - } - } - if include_dipfit: - eeg["dipfit"] = { - "coordformat": "MNI", - "model": [ - {"posxyz": [0, -20, 40], "momxyz": [1, 0, 0], "rv": 0.12, "component": 1}, - { - "posxyz": [[25, 10, 35], [-25, 10, 35]], - "momxyz": [[0, 1, 0], [0, 2, 0]], - "rv": 0.2, - "component": 2, - }, - {"posxyz": [], "momxyz": [], "rv": 1.0, "component": 3}, - {"posxyz": [], "momxyz": [], "rv": 1.0, "component": 4}, - ], - } - return eeg - - -def _axis_by_title(figure, title: str): - return next(axis for axis in figure.axes if axis.get_title() == title) - - -def _event_marker_labels(figure, title: str) -> list[str]: - axis = _axis_by_title(figure, title) - return [text.get_text() for text in axis.texts] - - -def _dashed_marker_x_positions(figure, title: str) -> list[float]: - axis = _axis_by_title(figure, title) - return [float(line.get_xdata()[0]) for line in axis.lines if line.get_linestyle() == "--"] - - -def _reject_button_label(figure) -> str: - return figure.eegprep_dashboard_rejection_buttons["status"].label.get_text() - - -def test_gui_dashboard_creation_has_eeglab_labels_titles_and_activity_browser() -> None: - eeg = _dashboard_eeg() - - figure = pop_prop_extended(eeg, 0, [1, 2], spec_opt="'freqrange', [2 40]", scroll_event=1) - - titles = [axis.get_title() for axis in figure.axes] - assert figure.eegprep_dashboard_data.index == 1 - assert figure._suptitle.get_text() == "IC1 - pop_prop_extended()" - assert "IC1" in titles - assert "ICLabel" in titles - assert "Scrolling IC1 Activity" in titles - assert "Continuous Data" in titles - assert "IC1 Activity Power Spectrum" in titles - assert figure.eegprep_activity_view.state.title == "Scrolling IC1 Activity -- eegplot()" - assert len(figure.eegprep_activity_view.state.events) == 1 - assert _event_marker_labels(figure, "Scrolling IC1 Activity") == ["stim"] - assert _dashed_marker_x_positions(figure, "Scrolling IC1 Activity") == [25.0] - assert set(figure.eegprep_dashboard_navigation) == {"previous", "next"} - plt.close(figure) - - -def test_gui_dashboard_navigation_updates_visible_component() -> None: - eeg = _dashboard_eeg() - figure = pop_prop_extended(eeg, 0, [1, 2], scroll_event=1) - - figure.eegprep_dashboard_navigation["next"]() - - assert figure.eegprep_dashboard_data.index == 2 - assert figure._suptitle.get_text() == "IC2 - pop_prop_extended()" - assert any(axis.get_title() == "Scrolling IC2 Activity" for axis in figure.axes) - plt.close(figure) - - -def test_gui_dashboard_rejection_controls_commit_selected_component_flags() -> None: - eeg = _dashboard_eeg() - figure = pop_prop_extended(eeg, 0, [1, 2], scroll_event=1) - - assert _reject_button_label(figure) == "ACCEPT" - figure.eegprep_dashboard_rejection["toggle"]() - assert _reject_button_label(figure) == "REJECT" - np.testing.assert_array_equal(eeg["reject"]["gcompreject"], [0, 0, 0, 0]) - - figure.eegprep_dashboard_navigation["next"]() - assert figure.eegprep_dashboard_data.index == 2 - assert _reject_button_label(figure) == "ACCEPT" - figure.eegprep_dashboard_rejection["toggle"]() - figure.eegprep_dashboard_navigation["previous"]() - assert _reject_button_label(figure) == "REJECT" - - figure.eegprep_dashboard_rejection["ok"]() - - np.testing.assert_array_equal(eeg["reject"]["gcompreject"], [1, 1, 0, 0]) - plt.close(figure) - - -def test_gui_dashboard_rejection_cancel_discards_pending_flags() -> None: - eeg = _dashboard_eeg() - figure = pop_prop_extended(eeg, 0, 1, scroll_event=1) - - figure.eegprep_dashboard_rejection["toggle"]() - figure.eegprep_dashboard_rejection["cancel"]() - - np.testing.assert_array_equal(eeg["reject"]["gcompreject"], [0, 0, 0, 0]) - plt.close(figure) - - -def test_gui_dashboard_rejection_commit_updates_callback_and_origin_button() -> None: - eeg = _dashboard_eeg() - origin_figure = plt.figure() - origin_button = Button(origin_figure.add_axes((0.1, 0.1, 0.2, 0.2)), "1") - calls = [] - figure = pop_prop_extended( - eeg, - 0, - 1, - winhandle=origin_button, - scroll_event=1, - reject_callback=lambda updated, states: calls.append((updated, states)), - ) - - figure.eegprep_dashboard_rejection["toggle"]() - figure.eegprep_dashboard_rejection["ok"]() - - assert len(calls) == 1 - assert calls[0][0] is eeg - assert calls[0][1] == {1: True} - np.testing.assert_array_equal(eeg["reject"]["gcompreject"], [1, 0, 0, 0]) - assert origin_button.ax.get_facecolor() == pytest.approx((1.0, 0.6, 0.6, 1.0)) - plt.close(figure) - plt.close(origin_figure) - - -def test_gui_dashboard_activity_browser_honors_event_display_option() -> None: - eeg = _dashboard_eeg() - - with_events = pop_prop_extended(eeg, 0, 1, scroll_event=1) - without_events = pop_prop_extended(eeg, 0, 1, scroll_event=0) - - assert len(with_events.eegprep_activity_view.state.events) == 1 - assert without_events.eegprep_activity_view.state.events == [] - assert _event_marker_labels(with_events, "Scrolling IC1 Activity") == ["stim"] - assert _event_marker_labels(without_events, "Scrolling IC1 Activity") == [] - plt.close(with_events) - plt.close(without_events) - - -def test_gui_dashboard_ignores_event_dict_without_latency() -> None: - eeg = _dashboard_eeg() - eeg["event"] = {"type": ["stim"]} - - figure = pop_prop_extended(eeg, 0, 1, scroll_event=1) - - assert _event_marker_labels(figure, "Scrolling IC1 Activity") == [] - plt.close(figure) - - -def test_gui_dashboard_epoched_static_events_use_flattened_event_latencies() -> None: - eeg = _dashboard_eeg() - eeg["data"] = np.repeat(eeg["data"][:, :, np.newaxis], 2, axis=2) - eeg["icaact"] = np.repeat(eeg["icaact"][:, :, np.newaxis], 2, axis=2) - eeg["trials"] = 2 - eeg["event"] = [ - {"type": "first", "latency": 25.0, "duration": 0.0, "epoch": 1}, - {"type": "second", "latency": 125.0, "duration": 0.0, "epoch": 2}, - ] - eeg["epoch"] = [ - {"event": [0], "eventtype": ["first"], "eventlatency": [0.0], "eventduration": [0.0]}, - {"event": [1], "eventtype": ["second"], "eventlatency": [0.0], "eventduration": [0.0]}, - ] - - figure = pop_prop_extended(eeg, 0, 1, scroll_event=1) - - marker_labels = _event_marker_labels(figure, "Scrolling IC1 Activity") - assert "first" in marker_labels - assert "second" in marker_labels - assert "epoch 1" in marker_labels - assert "epoch 2" in marker_labels - assert sorted(round(position) for position in _dashed_marker_x_positions(figure, "Scrolling IC1 Activity")) == [ - 25, - 125, - ] - assert [event.type for event in figure.eegprep_activity_view.state.events] == ["first", "second"] - plt.close(figure) - - -def test_gui_dashboard_renders_dipfit_three_view_surface() -> None: - eeg = _dashboard_eeg(include_dipfit=True) - - figure = pop_prop_extended(eeg, 0, 1, scroll_event=1) - - titles = [axis.get_title() for axis in figure.axes] - all_text = [text.get_text() for axis in figure.axes for text in axis.texts] - dipfit_axis = _axis_by_title(figure, "Dipole Position") - assert figure.eegprep_dashboard_data.dipfit is not None - assert "Dipole Position" in titles - assert any("RV: 12.0%" in text for text in all_text) - assert dipfit_axis.images - plt.close(figure) - - -def test_gui_dashboard_navigation_updates_dipfit_surface() -> None: - eeg = _dashboard_eeg(include_dipfit=True) - figure = pop_prop_extended(eeg, 0, [1, 2], scroll_event=1) - - figure.eegprep_dashboard_navigation["next"]() - - assert figure.eegprep_dashboard_data.index == 2 - assert figure.eegprep_dashboard_data.dipfit is not None - assert figure.eegprep_dashboard_data.dipfit.dmr == 2.0 - assert any("RV: 20.0%" in text.get_text() for axis in figure.axes for text in axis.texts) - plt.close(figure) - - -def test_pop_viewprops_component_mode_opens_extended_dashboard_when_classifier_is_available() -> None: - eeg = _dashboard_eeg() - - figures = pop_viewprops(eeg, 0, [1, 2], plot=True, show_activity=False) - - assert len(figures) == 1 - assert figures[0].eegprep_dashboard_data.index == 1 - assert figures[0].eegprep_dashboard_data.classifier.name == "ICLabel" - figures[0].eegprep_dashboard_navigation["next"]() - assert figures[0].eegprep_dashboard_data.index == 2 - plt.close(figures[0]) - - -def test_pop_prop_extended_dialog_and_help_are_packaged() -> None: - eeg = _dashboard_eeg() - - spec = pop_prop_extended_dialog_spec(eeg, 0) - help_text, source_path = pophelp_text("pop_prop_extended") - - assert spec.title == "Component properties - pop_prop_extended()" - assert spec.eeglab_source == "plugins/ICLabel/viewprops/pop_prop_extended.m" - assert spec.help_text == "pophelp('pop_prop_extended')" - assert "POP_PROP_EXTENDED" in help_text - assert source_path == "eegprep/resources/help/pop_prop_extended.md" diff --git a/tests/test_gui_rejection_dialogs.py b/tests/test_gui_rejection_dialogs.py deleted file mode 100644 index f9726444..00000000 --- a/tests/test_gui_rejection_dialogs.py +++ /dev/null @@ -1,356 +0,0 @@ -import unittest -from unittest import mock - -import numpy as np - -from eegprep.functions.adminfunc.console import _console_python_command -from eegprep.functions.guifunc.menu_actions import MenuActionDispatcher, action_kind -from eegprep.functions.guifunc.session import EEGPrepSession -from eegprep.functions.guifunc.spec import controls_by_tag -from eegprep.functions.popfunc.pop_autorej import pop_autorej_dialog_spec -from eegprep.functions.popfunc.pop_eegthresh import pop_eegthresh, pop_eegthresh_dialog_spec -from eegprep.functions.popfunc.pop_jointprob import pop_jointprob, pop_jointprob_dialog_spec -from eegprep.functions.popfunc.pop_rejchan import pop_rejchan_dialog_spec -from eegprep.functions.popfunc.pop_rejcont import pop_rejcont_dialog_spec -from eegprep.functions.popfunc.pop_rejkurt import pop_rejkurt, pop_rejkurt_dialog_spec -from eegprep.functions.popfunc.pop_rejmenu import pop_rejmenu_dialog_spec -from eegprep.functions.popfunc.pop_rejspec import pop_rejspec_dialog_spec -from eegprep.functions.popfunc.pop_rejtrend import pop_rejtrend_dialog_spec -from eegprep.functions.popfunc.pop_selectcomps import pop_selectcomps_dialog_spec -from eegprep.plugins.ICLabel.pop_viewprops import pop_viewprops, pop_viewprops_dialog_spec -from tests.fixtures import create_test_eeg - - -def _epoched_ica_eeg(): - eeg = create_test_eeg(n_channels=3, n_samples=40, n_trials=3, srate=100) - eeg["data"] = np.zeros((3, 40, 3)) - eeg["icaweights"] = np.eye(3) - eeg["icasphere"] = np.eye(3) - eeg["icawinv"] = np.eye(3) - eeg["icachansind"] = np.arange(3) - eeg["reject"] = { - "gcompreject": np.array([0, 1, 0]), - "rejthresh": np.array([0, 1, 0]), - "rejthreshE": np.zeros((3, 3), dtype=bool), - } - return eeg - - -class RejectionDialogTests(unittest.TestCase): - def test_dialog_specs_keep_eeglab_source_and_key_defaults(self): - eeg = _epoched_ica_eeg() - specs = [ - pop_eegthresh_dialog_spec(eeg, 1), - pop_jointprob_dialog_spec(eeg, 1), - pop_rejkurt_dialog_spec(eeg, 0), - pop_rejtrend_dialog_spec(eeg, 1), - pop_rejspec_dialog_spec(eeg, 0), - pop_rejchan_dialog_spec(eeg), - pop_rejcont_dialog_spec(eeg), - pop_autorej_dialog_spec(eeg), - pop_rejmenu_dialog_spec(eeg, 1), - pop_selectcomps_dialog_spec(eeg), - pop_viewprops_dialog_spec(eeg, 0), - ] - - for spec in specs: - self.assertIn("pop_", spec.function_name) - self.assertTrue(spec.eeglab_source.endswith(".m")) - self.assertIsNotNone(spec.size) - - self.assertEqual(controls_by_tag(specs[0])["elecrange"].value, "1:3") - self.assertEqual(controls_by_tag(specs[1])["vistype"].value, 2) - self.assertTrue(controls_by_tag(specs[1])["superpose"].value) - self.assertEqual(controls_by_tag(specs[2])["vistype"].value, 2) - self.assertTrue(controls_by_tag(specs[2])["superpose"].value) - self.assertTrue(controls_by_tag(specs[7])["eegplot"].enabled) - self.assertTrue(controls_by_tag(specs[7])["eegplot"].value) - self.assertTrue(controls_by_tag(specs[8])["scrollmanual"].enabled) - self.assertEqual(controls_by_tag(specs[8])["scrollmanual"].callback.name, "open_rejection_browser") - self.assertTrue(controls_by_tag(specs[10])["scroll_event"].enabled) - - def test_viewprops_component_dialog_includes_classifier_dropdown(self): - eeg = _epoched_ica_eeg() - eeg["etc"] = {"ic_classification": {"Other": {}, "ICLabel": {}}} - controls = controls_by_tag(pop_viewprops_dialog_spec(eeg, 0)) - - self.assertEqual(controls["classifier_name"].string, "Other|ICLabel") - self.assertEqual(controls["classifier_name"].value, 2) - - def test_component_probability_dialog_defaults_match_eeglab(self): - eeg = _epoched_ica_eeg() - - self.assertEqual(controls_by_tag(pop_jointprob_dialog_spec(eeg, 0))["locthresh"].value, "5") - self.assertEqual(controls_by_tag(pop_jointprob_dialog_spec(eeg, 0))["globthresh"].value, "5") - self.assertEqual(controls_by_tag(pop_rejkurt_dialog_spec(eeg, 0))["locthresh"].value, "5") - self.assertEqual(controls_by_tag(pop_rejkurt_dialog_spec(eeg, 0))["globthresh"].value, "5") - - def test_rejection_menu_actions_are_implemented_or_browser_excluded(self): - implemented = [ - "eeg_rejsuperpose:data_to_ica", - "pop_autorej", - "pop_eegthresh:data", - "pop_jointprob:ica", - "pop_rejchan", - "pop_rejcont", - "pop_rejepoch:data", - "pop_rejkurt:ica", - "pop_rejmenu:data", - "pop_rejspec:ica", - "pop_rejtrend:data", - "pop_selectcomps", - "pop_viewprops:channels", - "pop_viewprops:components", - ] - for action in implemented: - self.assertEqual(action_kind(action), "implemented") - self.assertEqual(action_kind("pop_eegplot:reject_data"), "implemented") - - def test_gui_command_is_valid_python_with_keywords(self): - class Renderer: - def run(self, spec, initial_values=None): - return { - "elecrange": "1", - "negthresh": "-5", - "posthresh": "5", - "starttime": "0", - "endtime": "0.39", - "superpose": False, - "reject": False, - } - - eeg = _epoched_ica_eeg() - eeg["data"][0, 2, 1] = 10 - _out, com = pop_eegthresh(eeg, gui=True, renderer=Renderer(), return_com=True, show=False) - - self.assertEqual( - _console_python_command(com), - "EEG = pop_eegthresh(EEG, icacomp=1, elecrange=[1], negthresh=[-5], " - "posthresh=[5], starttime=[0], endtime=[0.39], superpose=0, reject=0)", - ) - - def test_probability_dialog_commands_include_visualization_mode(self): - class Renderer: - def run(self, spec, initial_values=None): - return { - "elecrange": "1", - "locthresh": "4", - "globthresh": "6", - "vistype": 2, - "superpose": True, - "reject": False, - } - - eeg = _epoched_ica_eeg() - _joint_out, joint_com = pop_jointprob(eeg, gui=True, renderer=Renderer(), return_com=True, show=False) - _kurt_out, kurt_com = pop_rejkurt(eeg, gui=True, renderer=Renderer(), return_com=True, show=False) - - self.assertEqual( - _console_python_command(joint_com), - "EEG = pop_jointprob(EEG, icacomp=1, elecrange=[1], locthresh=[4], " - "globthresh=[6], superpose=1, reject=0, vistype=1, topcommand=[], plotflag=0)", - ) - self.assertEqual( - _console_python_command(kurt_com), - "EEG = pop_rejkurt(EEG, icacomp=1, elecrange=[1], locthresh=[4], " - "globthresh=[6], superpose=1, reject=0, vistype=1, topcommand=[], plotflag=0)", - ) - - def test_viewprops_gui_records_options_and_classifier(self): - class Renderer: - def run(self, spec, initial_values=None): - return { - "chanorcomp": "1", - "spec_opt": "'freqrange', [2 40]", - "erp_opt": "'limits', [-100 500]", - "scroll_event": False, - "classifier_name": 2, - } - - eeg = _epoched_ica_eeg() - eeg["etc"] = {"ic_classification": {"Other": {}, "ICLabel": {}}} - _figures, com = pop_viewprops(eeg, 0, gui=True, renderer=Renderer(), plot=False, return_com=True) - - self.assertEqual( - _console_python_command(com), - "pop_viewprops(EEG, typecomp=0, chanorcomp=[1], spec_opt=\"'freqrange', [2 40]\", " - "erp_opt=\"'limits', [-100 500]\", scroll_event=0, classifier_name='ICLabel')", - ) - - def test_viewprops_dispatch_records_history_without_replacing_dataset(self): - session = EEGPrepSession() - eeg = _epoched_ica_eeg() - session.store_current(eeg, new=True) - dispatcher = MenuActionDispatcher(session) - - with mock.patch( - "eegprep.plugins.ICLabel.pop_viewprops.pop_viewprops", - return_value=(["figure"], "pop_viewprops(EEG, 0, [1], [], [], 1, '')"), - ) as viewprops: - dispatcher.dispatch("pop_viewprops:components") - - viewprops.assert_called_once() - self.assertIs(viewprops.call_args.args[0], session.EEG) - self.assertEqual(viewprops.call_args.kwargs["typecomp"], 0) - self.assertTrue(callable(viewprops.call_args.kwargs["reject_callback"])) - self.assertTrue(viewprops.call_args.kwargs["return_com"]) - self.assertEqual(session.ALLCOM[-1], "pop_viewprops(EEG, 0, [1], [], [], 1, '')") - - def test_viewprops_reject_callback_stores_dashboard_component_marks(self): - session = EEGPrepSession() - eeg = _epoched_ica_eeg() - eeg["reject"]["gcompreject"] = np.zeros(3, dtype=int) - session.store_current(eeg, new=True) - dispatcher = MenuActionDispatcher(session) - captured = {} - - def fake_viewprops(selection, **kwargs): - captured["selection"] = selection - captured["reject_callback"] = kwargs["reject_callback"] - return ["figure"], "pop_viewprops(EEG, 0, [1], [], [], 1, '')" - - with mock.patch("eegprep.plugins.ICLabel.pop_viewprops.pop_viewprops", side_effect=fake_viewprops): - dispatcher.dispatch("pop_viewprops:components") - - captured["selection"]["reject"]["gcompreject"][1] = 1 - captured["reject_callback"](captured["selection"], {2: True}) - - np.testing.assert_array_equal(session.EEG["reject"]["gcompreject"], [0, 1, 0]) - np.testing.assert_array_equal(session.ALLEEG[0]["reject"]["gcompreject"], [0, 1, 0]) - - def test_reject_marked_epochs_uses_rejglobal_for_ica_menu(self): - session = EEGPrepSession() - eeg = _epoched_ica_eeg() - eeg["reject"]["rejglobal"] = np.array([False, True, False]) - eeg["reject"]["icarejglobal"] = np.array([False, False, False]) - session.store_current(eeg, new=True) - dispatcher = MenuActionDispatcher(session) - - with mock.patch( - "eegprep.functions.popfunc.pop_rejepoch.pop_rejepoch", - return_value=(eeg, "EEG = pop_rejepoch(EEG, [2], 1);"), - ) as rejepoch: - dispatcher.dispatch("pop_rejepoch:ica") - - rejepoch.assert_called_once() - np.testing.assert_array_equal(rejepoch.call_args.args[1], np.array([False, True, False])) - - def test_eegplot_browser_accept_callback_is_session_safe(self): - session = EEGPrepSession() - eeg = create_test_eeg(n_channels=1, n_samples=10, n_trials=1, srate=10) - session.store_current(eeg, new=True) - dispatcher = MenuActionDispatcher(session) - captured = {} - - def fake_pop_eegplot(selection, **kwargs): - captured["selection"] = selection - captured["callback"] = kwargs["command_callback"] - return "window" - - with mock.patch("eegprep.functions.popfunc.pop_eegplot.pop_eegplot", side_effect=fake_pop_eegplot): - dispatcher.dispatch("pop_eegplot:reject_data") - - self.assertEqual(session.ALLCOM, []) - self.assertEqual(captured["selection"]["setname"], "test_dataset") - updated = dict(eeg, setname="accepted") - captured["callback"](updated, "pop_eegplot(EEG, 1, 0, 1)") - - self.assertEqual(session.EEG["setname"], "accepted") - self.assertEqual(session.ALLCOM[-1], "pop_eegplot(EEG, 1, 0, 1)") - self.assertEqual(session.LASTCOM, "pop_eegplot(EEG, 1, 0, 1)") - - def test_rejcont_dispatch_defers_history_until_browser_accept(self): - session = EEGPrepSession() - eeg = create_test_eeg(n_channels=1, n_samples=10, n_trials=1, srate=10) - session.store_current(eeg, new=True) - dispatcher = MenuActionDispatcher(session) - captured = {} - command = "EEG = pop_rejcont(EEG, 'eegplot', 'on');" - - def fake_pop_rejcont(selection, **kwargs): - captured["selection"] = selection - captured["callback"] = kwargs["command_callback"] - return selection, command - - with mock.patch("eegprep.functions.popfunc.pop_rejcont.pop_rejcont", side_effect=fake_pop_rejcont): - dispatcher.dispatch("pop_rejcont") - - self.assertEqual(session.ALLCOM, []) - self.assertEqual(len(session.ALLEEG), 1) - self.assertEqual(captured["selection"]["setname"], "test_dataset") - - accepted = dict(eeg) - accepted["data"] = np.asarray(eeg["data"])[:, :5] - accepted["pnts"] = 5 - captured["callback"](accepted, command) - - self.assertEqual(session.CURRENTSET, [2]) - self.assertEqual(session.EEG["pnts"], 5) - self.assertEqual(len(session.ALLEEG), 2) - self.assertEqual(session.ALLCOM, [command]) - - def test_rejection_dispatch_stores_multi_dataset_results_without_browser_callback(self): - session = EEGPrepSession() - first = _epoched_ica_eeg() - first["setname"] = "first" - second = _epoched_ica_eeg() - second["setname"] = "second" - session.store_current(first, new=True) - session.store_current(second, new=True) - session.retrieve([1, 2]) - dispatcher = MenuActionDispatcher(session) - captured = {} - command = "EEG = pop_jointprob(EEG, 1, [1], 4, 4, 0, 1, 1);" - - def fake_pop_jointprob(selection, icacomp, **kwargs): - captured["selection"] = selection - captured["icacomp"] = icacomp - captured["kwargs"] = kwargs - output = [dict(item, setname=f"{item['setname']}-marked") for item in selection] - return output, command - - with mock.patch("eegprep.functions.popfunc.pop_jointprob.pop_jointprob", side_effect=fake_pop_jointprob): - dispatcher.dispatch("pop_jointprob:data") - - self.assertEqual([item["setname"] for item in captured["selection"]], ["first", "second"]) - self.assertEqual(captured["icacomp"], 1) - self.assertTrue(captured["kwargs"]["return_com"]) - self.assertNotIn("command_callback", captured["kwargs"]) - self.assertEqual(session.CURRENTSET, [1, 2]) - self.assertEqual([item["setname"] for item in session.ALLEEG], ["first-marked", "second-marked"]) - self.assertEqual(session.ALLCOM, [command]) - - def test_rejection_browser_accept_callback_creates_dataset_without_duplicate_history(self): - session = EEGPrepSession() - eeg = _epoched_ica_eeg() - session.store_current(eeg, new=True) - dispatcher = MenuActionDispatcher(session) - captured = {} - command = "EEG = pop_jointprob(EEG, 1, [1], 4, 4, 0, 1, 1);" - - def fake_pop_jointprob(selection, icacomp, **kwargs): - captured["selection"] = selection - captured["icacomp"] = icacomp - captured["callback"] = kwargs["command_callback"] - return selection, command - - with mock.patch("eegprep.functions.popfunc.pop_jointprob.pop_jointprob", side_effect=fake_pop_jointprob): - dispatcher.dispatch("pop_jointprob:data") - - self.assertEqual(session.ALLCOM, [command]) - self.assertEqual(captured["selection"]["setname"], "test_dataset") - self.assertEqual(captured["icacomp"], 1) - - accepted = dict(eeg) - accepted["data"] = np.asarray(eeg["data"])[:, :, :2] - accepted["trials"] = 2 - captured["callback"](accepted, command) - - self.assertEqual(session.CURRENTSET, [2]) - self.assertEqual(session.EEG["trials"], 2) - self.assertEqual(len(session.ALLEEG), 2) - self.assertEqual(session.ALLCOM, [command]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_iclabel.py b/tests/test_iclabel.py deleted file mode 100644 index e538291a..00000000 --- a/tests/test_iclabel.py +++ /dev/null @@ -1,71 +0,0 @@ -import os -import unittest -import numpy as np -from eegprep import ICL_feature_extractor, iclabel, pop_loadset -from eegprep.utils.testing import has_optional_dependency - -# where the test resources -local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') - - -@unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") -class TestICLabelEngines(unittest.TestCase): - def setUp(self): - self.EEG = pop_loadset(os.path.join(local_url, 'eeglab_data_with_ica_tmp.set')) - - def test_basic(self): - if not has_optional_dependency('torch'): - self.skipTest("PyTorch is not installed; install eegprep[torch] to run ICLabel parity") - - features_python = ICL_feature_extractor(self.EEG, True) - print(f"\n{'=' * 60}") - print("FEATURE EXTRACTION COMPARISON") - print(f"{'=' * 60}") - print(f"Python features[0] (topo) shape: {features_python[0].shape}, dtype: {features_python[0].dtype}") - print(f"Python features[1] (psd) shape: {features_python[1].shape}, dtype: {features_python[1].dtype}") - print(f"Python features[2] (autocorr) shape: {features_python[2].shape}, dtype: {features_python[2].dtype}") - print(f"Python topo max: {np.max(features_python[0]):.6f}, min: {np.min(features_python[0]):.6f}") - print(f"Python psd max: {np.max(features_python[1]):.6f}, min: {np.min(features_python[1]):.6f}") - print(f"Python autocorr max: {np.max(features_python[2]):.6f}, min: {np.min(features_python[2]):.6f}") - print(f"{'=' * 60}\n") - - EEG_python = iclabel(self.EEG, algorithm='default', engine=None) - EEG_matlab = iclabel(self.EEG, algorithm='default', engine='matlab') - - res1 = EEG_python['etc']['ic_classification']['ICLabel']['classifications'].flatten() - res2 = EEG_matlab['etc']['ic_classification']['ICLabel']['classifications'].flatten() - - # Diagnostic output - print(f"\n{'=' * 60}") - print("DIAGNOSTIC OUTPUT") - print(f"{'=' * 60}") - print(f"Python result dtype: {res1.dtype}") - print(f"MATLAB result dtype: {res2.dtype}") - print(f"Python result shape: {res1.shape}") - print(f"MATLAB result shape: {res2.shape}") - print(f"\nMax absolute difference: {np.max(np.abs(res1 - res2)):.2e}") - print(f"Mean absolute difference: {np.mean(np.abs(res1 - res2)):.2e}") - print(f"Max relative difference: {np.max(np.abs(res1 - res2) / (np.abs(res2) + 1e-10)):.2e}") - print(f"Mean relative difference: {np.mean(np.abs(res1 - res2) / (np.abs(res2) + 1e-10)):.2e}") - print(f"\nPython results (first 20 values):\n{res1[:20]}") - print(f"\nMATLAB results (first 20 values):\n{res2[:20]}") - print(f"\nDifferences (first 20 values):\n{(res1 - res2)[:20]}") - print(f"\nRelative differences (first 20 values):\n{((res1 - res2) / (res2 + 1e-10))[:20]}") - - # Count how many values exceed tolerances - abs_diffs = np.abs(res1 - res2) - rel_diffs = np.abs(res1 - res2) / (np.abs(res2) + 1e-10) - exceeds_abs = abs_diffs > 1e-8 - exceeds_rel = rel_diffs > 1e-5 - exceeds_both = exceeds_abs & exceeds_rel - print(f"\nValues exceeding absolute tolerance (1e-8): {np.sum(exceeds_abs)}/{len(res1)}") - print(f"Values exceeding relative tolerance (1e-5): {np.sum(exceeds_rel)}/{len(res1)}") - print(f"Values exceeding BOTH tolerances: {np.sum(exceeds_both)}/{len(res1)}") - print(f"{'=' * 60}\n") - - # Max abs diff: 3.37e-06, Max rel diff: 4.25e-05 - self.assertTrue(np.allclose(res1, res2, rtol=1e-4, atol=1e-5), 'ICLabel results differ beyond tolerance') - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_iclabel_features.py b/tests/test_iclabel_features.py deleted file mode 100644 index 757d0257..00000000 --- a/tests/test_iclabel_features.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -Test to compare ICLabel features between Python and MATLAB implementations. -Tests both float32 and float64 precision. -""" - -import os -import unittest -import numpy as np -from eegprep import pop_loadset, ICL_feature_extractor -from eegprep.functions.adminfunc.eeglabcompat import get_eeglab -import tempfile -import scipy.io - -local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') - - -@unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") -class TestICLabelFeatureComparison(unittest.TestCase): - def setUp(self): - self.EEG = pop_loadset(os.path.join(local_url, 'eeglab_data_with_ica_tmp.set')) - - def test_feature_comparison_float32(self): - """Compare Python vs MATLAB features in float32.""" - print(f"\n{'=' * 70}") - print("FEATURE COMPARISON: FLOAT32 (Default)") - print(f"{'=' * 70}") - - # Extract Python features (float32) - features_py = ICL_feature_extractor(self.EEG, True) - - # Extract MATLAB features using direct MATLAB call - eeglab = get_eeglab('MAT', auto_file_roundtrip=False) - - # Save EEG to temp file for MATLAB - from eegprep import pop_saveset - - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - # Call MATLAB to extract features and save them - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - features = ICL_feature_extractor(EEG, true); - save('{temp_file}.mat', 'features'); - """ - eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB features - mat_data = scipy.io.loadmat(temp_file + '.mat') - features_mat = [mat_data['features'][0, 0], mat_data['features'][0, 1], mat_data['features'][0, 2]] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - self._compare_features(features_py, features_mat, "FLOAT32") - - def test_feature_comparison_float64(self): - """Compare Python vs MATLAB features in float64.""" - print(f"\n{'=' * 70}") - print("FEATURE COMPARISON: FLOAT64 (Modified)") - print(f"{'=' * 70}") - - # Extract Python features and convert to float64 - features_py32 = ICL_feature_extractor(self.EEG, True) - features_py = [ - features_py32[0].astype(np.float64), - features_py32[1].astype(np.float64), - features_py32[2].astype(np.float64), - ] - - # Extract MATLAB features - eeglab = get_eeglab('MAT', auto_file_roundtrip=False) - - # Save EEG to temp file - from eegprep import pop_saveset - - temp_file = tempfile.mktemp(suffix='.set') - pop_saveset(self.EEG, temp_file) - - # Call MATLAB to extract features and save them - matlab_code = f""" - EEG = pop_loadset('{temp_file}'); - features = ICL_feature_extractor(EEG, true); - save('{temp_file}.mat', 'features'); - """ - eeglab.eval(matlab_code, nargout=0) - - # Load and convert to float64 - mat_data = scipy.io.loadmat(temp_file + '.mat') - features_mat = [ - mat_data['features'][0, 0].astype(np.float64), - mat_data['features'][0, 1].astype(np.float64), - mat_data['features'][0, 2].astype(np.float64), - ] - - # Clean up - os.remove(temp_file) - os.remove(temp_file + '.mat') - if os.path.exists(temp_file.replace('.set', '.fdt')): - os.remove(temp_file.replace('.set', '.fdt')) - - self._compare_features(features_py, features_mat, "FLOAT64") - - def _compare_features(self, features_py, features_mat, precision): - """Helper to compare and report feature differences.""" - feature_names = ['Topo', 'PSD', 'Autocorr'] - - for i, name in enumerate(feature_names): - py_feat = features_py[i] - mat_feat = features_mat[i] - - print(f"\n{'-' * 70}") - print(f"{name} Feature ({precision})") - print(f"{'-' * 70}") - - # Shape and dtype - print(f"Python shape: {py_feat.shape}, dtype: {py_feat.dtype}") - print(f"MATLAB shape: {mat_feat.shape}, dtype: {mat_feat.dtype}") - - # Min/Max values - print(f"\nPython - min: {np.min(py_feat):+.10f}, max: {np.max(py_feat):+.10f}") - print(f"MATLAB - min: {np.min(mat_feat):+.10f}, max: {np.max(mat_feat):+.10f}") - - # Explain min/max - print("\nMin/Max Explanation:") - print(" - Features are scaled by 0.99 in ICL_feature_extractor") - print(" - Expected range: [-0.99, +0.99] for topo, [0, 0.99] for others") - if np.min(py_feat) < -0.99 or np.max(py_feat) > 0.99: - print(" ⚠️ Python feature EXCEEDS expected range!") - if np.min(mat_feat) < -0.99 or np.max(mat_feat) > 0.99: - print(" ⚠️ MATLAB feature EXCEEDS expected range!") - - # Statistics - print(f"\nPython - mean: {np.mean(py_feat):+.10f}, std: {np.std(py_feat):.10f}") - print(f"MATLAB - mean: {np.mean(mat_feat):+.10f}, std: {np.std(mat_feat):.10f}") - - # Differences - if py_feat.shape == mat_feat.shape: - diff = py_feat - mat_feat - abs_diff = np.abs(diff) - rel_diff = np.abs(diff) / (np.abs(mat_feat) + 1e-10) - - print("\nDifference Statistics:") - print(f" Max absolute diff: {np.max(abs_diff):.2e}") - print(f" Mean absolute diff: {np.mean(abs_diff):.2e}") - print(f" Max relative diff: {np.max(rel_diff):.2e}") - print(f" Mean relative diff: {np.mean(rel_diff):.2e}") - - # Count values exceeding tolerances - exceeds_abs_1e8 = np.sum(abs_diff > 1e-8) - exceeds_abs_1e6 = np.sum(abs_diff > 1e-6) - exceeds_rel_1e5 = np.sum(rel_diff > 1e-5) - total = py_feat.size - - print("\nValues exceeding tolerances:") - print(f" |diff| > 1e-8: {exceeds_abs_1e8:6d}/{total} ({100 * exceeds_abs_1e8 / total:.1f}%)") - print(f" |diff| > 1e-6: {exceeds_abs_1e6:6d}/{total} ({100 * exceeds_abs_1e6 / total:.1f}%)") - print(f" rel diff > 1e-5: {exceeds_rel_1e5:6d}/{total} ({100 * exceeds_rel_1e5 / total:.1f}%)") - - # Are they close? - is_close_1e5 = np.allclose(py_feat, mat_feat, rtol=1e-5, atol=1e-8) - is_close_1e4 = np.allclose(py_feat, mat_feat, rtol=1e-4, atol=1e-6) - is_close_1e3 = np.allclose(py_feat, mat_feat, rtol=1e-3, atol=1e-5) - - print("\nallclose() results:") - print(f" rtol=1e-5, atol=1e-8: {is_close_1e5}") - print(f" rtol=1e-4, atol=1e-6: {is_close_1e4}") - print(f" rtol=1e-3, atol=1e-5: {is_close_1e3}") - else: - print("\n⚠️ Shape mismatch! Cannot compare values.") - - print(f"\n{'=' * 70}\n") - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_iclabel_statistics.py b/tests/test_iclabel_statistics.py deleted file mode 100644 index e84b78a4..00000000 --- a/tests/test_iclabel_statistics.py +++ /dev/null @@ -1,126 +0,0 @@ -from __future__ import annotations - -import copy - -import numpy as np -import pytest - -from eegprep import pop_loadset -from eegprep.functions.guifunc.pophelp import pophelp_text -from eegprep.plugins.ICLabel.eeg_icalabelstat import eeg_icalabelstat -from eegprep.plugins.ICLabel.iclabel import iclabel -from eegprep.plugins.ICLabel.pop_icflag import ICLABEL_CLASSES -from eegprep.plugins.ICLabel.pop_viewprops import pop_viewprops - - -def _classified_eeg() -> dict: - return { - "setname": "classified", - "data": np.zeros((4, 100)), - "nbchan": 4, - "pnts": 100, - "trials": 1, - "srate": 100.0, - "icaweights": np.eye(4), - "icasphere": np.eye(4), - "icawinv": np.eye(4), - "icachansind": np.arange(4), - "reject": {"gcompreject": np.array([0, 1, 1, 0])}, - "etc": { - "ic_classification": { - "ICLabel": { - "classes": list(ICLABEL_CLASSES), - "classifications": np.array( - [ - [0.70, 0.10, 0.10, 0.03, 0.02, 0.03, 0.02], - [0.02, 0.94, 0.02, 0.01, 0.00, 0.00, 0.01], - [0.05, 0.02, 0.91, 0.01, 0.00, 0.00, 0.01], - [0.80, 0.05, 0.05, 0.02, 0.02, 0.03, 0.03], - ] - ), - } - } - }, - } - - -def test_eeg_icalabelstat_matches_eeglab_threshold_counts_and_prints_summary(capsys) -> None: - stats = eeg_icalabelstat(_classified_eeg(), threshold=0.9) - - assert stats["classes"] == list(ICLABEL_CLASSES) - assert stats["component_count"] == 4 - np.testing.assert_array_equal(stats["counts"], [0, 1, 1, 0, 0, 0, 0]) - assert stats["component_indices"][1] == [2] - assert stats["component_indices"][2] == [3] - np.testing.assert_array_equal(stats["rejected_counts"], [0, 1, 1, 0, 0, 0, 0]) - np.testing.assert_array_equal(stats["kept_counts"], [0, 0, 0, 0, 0, 0, 0]) - - lines = capsys.readouterr().out.splitlines() - assert len(lines) == len(ICLABEL_CLASSES) - assert lines[0].strip() == 'IClabel class "Brain": 0/4 components at 90% threshold' - assert lines[1].strip() == 'IClabel class "Muscle": 1/4 components at 90% threshold' - assert lines[2].strip() == 'IClabel class "Eye": 1/4 components at 90% threshold' - - -def test_eeg_icalabelstat_accepts_class_specific_thresholds_and_default_classes() -> None: - eeg = _classified_eeg() - eeg["etc"]["ic_classification"]["ICLabel"].pop("classes") - - stats = eeg_icalabelstat(eeg, threshold=[0.6, 0.9, 0.9, 0.5, 0.5, 0.5, 0.5], verbose=False) - - assert stats["classes"] == list(ICLABEL_CLASSES) - np.testing.assert_array_equal(stats["counts"], [2, 1, 1, 0, 0, 0, 0]) - np.testing.assert_allclose(stats["threshold"], [0.6, 0.9, 0.9, 0.5, 0.5, 0.5, 0.5]) - np.testing.assert_array_equal(stats["dominant_counts"], [2, 1, 1, 0, 0, 0, 0]) - - -def test_eeg_icalabelstat_rejects_missing_or_malformed_iclabel_state() -> None: - eeg = _classified_eeg() - eeg["etc"] = {} - - with pytest.raises(ValueError, match="No ICLabel classifications"): - eeg_icalabelstat(eeg, verbose=False) - - malformed = _classified_eeg() - malformed["etc"]["ic_classification"]["ICLabel"]["classes"] = ["Brain"] - with pytest.raises(ValueError, match="ICLabel class list has 1 labels"): - eeg_icalabelstat(malformed, verbose=False) - - -def test_sample_data_ica_iclabel_state_drives_statistics_and_viewprops_history() -> None: - eeg = pop_loadset("sample_data/eeglab_data_with_ica_tmp.set") - classifications = np.zeros((eeg["icaweights"].shape[0], len(ICLABEL_CLASSES)), dtype=float) - classifications[:, 0] = 0.8 - classifications[0, 1] = 0.95 - classifications[1, 2] = 0.96 - eeg = copy.deepcopy(eeg) - eeg.setdefault("etc", {})["ic_classification"] = { - "ICLabel": {"classes": list(ICLABEL_CLASSES), "classifications": classifications} - } - - stats = eeg_icalabelstat(eeg, threshold=0.9, verbose=False) - _figures, command = pop_viewprops(eeg, 0, [1, 2], plot=False, return_com=True) - - assert stats["component_count"] == 32 - np.testing.assert_array_equal(stats["counts"][:3], [0, 1, 1]) - assert command == "pop_viewprops(EEG, 0, [1 2], [], [], 1, '');" - - -def test_python_iclabel_rejects_unbundled_alternate_networks_before_runtime_dependencies() -> None: - eeg = _classified_eeg() - - with pytest.raises(NotImplementedError, match="standalone Python ICLabel only ships the default network"): - iclabel(eeg, algorithm="lite", engine=None) - - -def test_eeg_icalabelstat_help_is_packaged() -> None: - help_text, source_path = pophelp_text("eeg_icalabelstat") - - assert "EEG_ICALABELSTAT" in help_text - assert source_path == "eegprep/resources/help/eeg_icalabelstat.md" - - -def test_eeg_icalabelstat_is_public_lazy_export() -> None: - from eegprep import eeg_icalabelstat as exported - - assert exported is eeg_icalabelstat diff --git a/tests/test_parity_rng.py b/tests/test_parity_rng.py deleted file mode 100644 index 6d57ea54..00000000 --- a/tests/test_parity_rng.py +++ /dev/null @@ -1,335 +0,0 @@ -""" -Test parity of random number generation between Python and MATLAB. - -IMPORTANT: The RNG mechanism works as follows: -1. Both Python and MATLAB use seed 5489 (MATLAB default) -2. Both use the Mersenne Twister algorithm (MT19937) -3. The rand() uniform distribution DOES produce the same sequence -4. However, randn() normal distribution does NOT match between implementations -5. For parity, use rand() + round_mat() + custom sampling (see ransac.py:rand_sample) - -This mechanism is used throughout the codebase (e.g., clean_channels.py:111, -eeg_picard.py:47, ransac.py:9-31) to ensure reproducible results. -""" - -import os -import unittest -import numpy as np -import tempfile -import scipy.io -from eegprep.functions.adminfunc.eeglabcompat import get_eeglab -from eegprep.plugins.clean_rawdata.private.ransac import rand_sample -from eegprep.functions.miscfunc.misc import round_mat - - -class TestRNGParity(unittest.TestCase): - """Test that Python and MATLAB produce identical random sequences using rand().""" - - def setUp(self): - """Set up test fixtures.""" - # Try to get MATLAB engine - try: - self.eeglab = get_eeglab('MAT', auto_file_roundtrip=False) - self.matlab_available = True - except Exception as e: - self.matlab_available = False - self.skipTest(f"MATLAB not available: {e}") - - def test_rng_uniform_parity(self): - """Test that rand() (uniform) produces the SAME ORDERED sequence (1D) in Python and MATLAB.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - seed = 5489 # MATLAB default - - # Python random sequence (uniform) - 1D to avoid column/row major issues - rng_py = np.random.RandomState(seed) - py_rand = rng_py.rand(50) # 1D array - - # MATLAB random sequence (uniform) - force 1D - temp_file = tempfile.mktemp(suffix='.mat') - matlab_code = f""" - rng({seed}, 'twister'); - ml_rand = rand(1, 50); % Row vector to ensure 1D - save('{temp_file}', 'ml_rand'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB results - mat_data = scipy.io.loadmat(temp_file) - ml_rand = mat_data['ml_rand'].flatten() - - # Clean up - os.remove(temp_file) - - # Compare rand (uniform distribution) - THIS SHOULD MATCH for 1D - print("\nrand() 1D uniform comparison:") - print(f" Python first 5 values: {py_rand[:5]}") - print(f" MATLAB first 5 values: {ml_rand[:5]}") - print(f" Max absolute diff: {np.max(np.abs(py_rand - ml_rand)):.2e}") - - np.testing.assert_allclose( - py_rand, ml_rand, rtol=1e-15, atol=1e-15, err_msg="rand() 1D uniform should produce identical sequences" - ) - - def test_rng_normal_incompatibility(self): - """Test that randn() (normal) produces DIFFERENT sequences (known incompatibility).""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - seed = 5489 - - # Python random sequence (normal) - rng_py = np.random.RandomState(seed) - py_randn = rng_py.randn(10, 5) - - # MATLAB random sequence (normal) - temp_file = tempfile.mktemp(suffix='.mat') - matlab_code = f""" - rng({seed}, 'twister'); - ml_randn = randn(10, 5); - save('{temp_file}', 'ml_randn'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB results - mat_data = scipy.io.loadmat(temp_file) - ml_randn = mat_data['ml_randn'] - - # Clean up - os.remove(temp_file) - - # Compare randn (normal distribution) - THIS SHOULD DIFFER - print("\nrandn() normal comparison (EXPECTED TO DIFFER):") - print(f" Python first 3 values: {py_randn.flatten()[:3]}") - print(f" MATLAB first 3 values: {ml_randn.flatten()[:3]}") - print(f" Max absolute diff: {np.max(np.abs(py_randn - ml_randn)):.2e}") - - are_different = not np.allclose(py_randn, ml_randn, rtol=1e-10, atol=1e-10) - self.assertTrue( - are_different, - "randn() normal distribution differs between Python and MATLAB (this is expected - use rand() for parity)", - ) - - def test_rand_sample_mechanism(self): - """Test the rand_sample mechanism that provides MATLAB parity.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - seed = 5489 - n = 20 # sample from 20 items - m = 5 # select 5 items - - # Python rand_sample (from ransac.py) - rng_py = np.random.RandomState(seed) - py_sample = rand_sample(n, m, rng_py) - - # MATLAB equivalent (Fisher-Yates shuffle to match Python implementation) - temp_file = tempfile.mktemp(suffix='.mat') - matlab_code = f""" - rng({seed}, 'twister'); - n = {n}; - m = {m}; - pool = 0:(n-1); % 0-indexed to match Python - - % Fisher-Yates shuffle (matches Python rand_sample implementation) - for k = 1:m - python_k = k - 1; % Convert to 0-indexed - remaining = n - python_k; - choice = round((remaining - 1) * rand()); - idx = k + choice; % k is 1-indexed, choice is 0-indexed offset - - % Swap pool(k) with pool(idx) - temp = pool(k); - pool(k) = pool(idx); - pool(idx) = temp; - end - - ml_sample = pool(1:m); % First m elements - save('{temp_file}', 'ml_sample'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB results - mat_data = scipy.io.loadmat(temp_file) - ml_sample = mat_data['ml_sample'].flatten().astype(int) - - # Clean up - os.remove(temp_file) - - # Compare - print("\nrand_sample comparison:") - print(f" Python sample: {py_sample}") - print(f" MATLAB sample: {ml_sample}") - - np.testing.assert_array_equal(py_sample, ml_sample, err_msg="rand_sample should produce identical results") - - def test_round_mat_parity(self): - """Test that round_mat matches MATLAB's round() behavior.""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - # Test values including tie-breaking cases - test_values = np.array([0.5, 1.5, 2.5, -0.5, -1.5, -2.5, 0.49, 0.51]) - - # Python round_mat - py_rounded = np.array([round_mat(x) for x in test_values]) - - # MATLAB round - temp_file = tempfile.mktemp(suffix='.mat') - matlab_code = f""" - test_values = {test_values.tolist()}; - ml_rounded = round(test_values); - save('{temp_file}', 'ml_rounded'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB results - mat_data = scipy.io.loadmat(temp_file) - ml_rounded = mat_data['ml_rounded'].flatten() - - # Clean up - os.remove(temp_file) - - # Compare - print("\nround_mat comparison:") - print(f" Test values: {test_values}") - print(f" Python round: {py_rounded}") - print(f" MATLAB round: {ml_rounded}") - - np.testing.assert_array_equal(py_rounded, ml_rounded, err_msg="round_mat should match MATLAB round()") - - def test_rng_permutation_compatibility(self): - """Test that permutation differs (known incompatibility).""" - if not self.matlab_available: - self.skipTest("MATLAB not available") - - seed = 5489 - n = 20 - - # Python permutation - rng_py = np.random.RandomState(seed) - py_perm = rng_py.permutation(n) - - # MATLAB permutation - temp_file = tempfile.mktemp(suffix='.mat') - matlab_code = f""" - rng({seed}, 'twister'); - ml_perm = randperm({n}); - save('{temp_file}', 'ml_perm'); - """ - self.eeglab.eval(matlab_code, nargout=0) - - # Load MATLAB results - mat_data = scipy.io.loadmat(temp_file) - ml_perm = mat_data['ml_perm'].flatten() - - # Clean up - os.remove(temp_file) - - # Check if they differ (they should - different algorithms) - are_different = not np.array_equal(py_perm, ml_perm) - - print("\nPermutation comparison (EXPECTED TO DIFFER):") - print(f" Python permutation: {py_perm[:5]}...") - print(f" MATLAB permutation: {ml_perm[:5]}...") - print(f" Are different: {are_different}") - - # Document this known incompatibility - self.assertTrue( - are_different, - "Permutation algorithms differ between Python and MATLAB " - "(this is expected - randperm uses different algorithm than permutation)", - ) - - -class TestRNGIsolation(unittest.TestCase): - """Demonstrate RNG mechanism without MATLAB dependency.""" - - def test_python_rng_deterministic(self): - """Test that Python RNG is deterministic with same seed.""" - seed = 5489 - - # First run - rng1 = np.random.RandomState(seed) - seq1 = rng1.rand(100) # Use rand() for consistency - - # Second run with same seed - rng2 = np.random.RandomState(seed) - seq2 = rng2.rand(100) - - # Should be identical - np.testing.assert_array_equal(seq1, seq2, err_msg="Same seed should produce identical sequence") - - def test_python_rng_different_seeds(self): - """Test that different seeds produce different sequences.""" - rng1 = np.random.RandomState(5489) - seq1 = rng1.rand(100) - - rng2 = np.random.RandomState(12345) - seq2 = rng2.rand(100) - - # Should be different - self.assertFalse(np.array_equal(seq1, seq2), "Different seeds should produce different sequences") - - def test_matlab_default_seed_value(self): - """Document that 5489 is MATLAB's default seed.""" - # This is the seed value used throughout the codebase: - # - clean_channels.py:111 - # - eeg_picard.py:47 - # - Corresponds to MATLAB's rng('default') - - matlab_default_seed = 5489 - - # Create RNG with this seed - rng = np.random.RandomState(matlab_default_seed) - first_uniform_value = rng.rand() - - # Expected first value when using seed 5489 - # (Verified against MATLAB: rng(5489,'twister'); rand) - expected_first_uniform_value = 0.8147236863931789 # Matches MATLAB output - self.assertAlmostEqual(first_uniform_value, expected_first_uniform_value) - - # Reset and check - rng = np.random.RandomState(matlab_default_seed) - actual_first_uniform_value = rng.rand() - - self.assertAlmostEqual( - actual_first_uniform_value, - expected_first_uniform_value, - places=15, - msg="MATLAB default seed (5489) should produce expected first uniform value", - ) - - def test_rand_sample_deterministic(self): - """Test that rand_sample is deterministic.""" - seed = 5489 - n = 20 - m = 5 - - # First run - rng1 = np.random.RandomState(seed) - sample1 = rand_sample(n, m, rng1) - - # Second run - rng2 = np.random.RandomState(seed) - sample2 = rand_sample(n, m, rng2) - - # Should be identical - np.testing.assert_array_equal(sample1, sample2, err_msg="rand_sample with same seed should be deterministic") - - def test_round_mat_tie_breaking(self): - """Test round_mat's tie-breaking behavior (rounds away from zero).""" - # MATLAB rounds ties (.5) away from zero - # Python's round() rounds ties to even (banker's rounding) - # round_mat should match MATLAB - - self.assertEqual(round_mat(0.5), 1.0) # Round up - self.assertEqual(round_mat(-0.5), -1.0) # Round down (away from zero) - self.assertEqual(round_mat(1.5), 2.0) # Round up - self.assertEqual(round_mat(-1.5), -2.0) # Round down (away from zero) - self.assertEqual(round_mat(2.5), 3.0) # Round up - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_phase4_plot_wrappers.py b/tests/test_phase4_plot_wrappers.py deleted file mode 100644 index 118f3140..00000000 --- a/tests/test_phase4_plot_wrappers.py +++ /dev/null @@ -1,1349 +0,0 @@ -from __future__ import annotations - -import ast -from copy import deepcopy -import importlib -import os -from pathlib import Path -from typing import Any - -import matplotlib - -matplotlib.use("Agg") - -import matplotlib.pyplot as plt -import numpy as np -import pytest -import scipy.io - -from eegprep.functions.guifunc.qt import QtDialogRenderer -from eegprep.functions.guifunc.spec import controls_by_tag -from eegprep.functions.popfunc.pop_comperp import pop_comperp, pop_comperp_dialog_spec -from eegprep.functions.popfunc.pop_envtopo import pop_envtopo -from eegprep.functions.popfunc.pop_erpimage import pop_erpimage, pop_erpimage_dialog_spec -from eegprep.functions.popfunc.pop_headplot import ( - pop_headplot, - pop_headplot_dialog_spec, - _current_spline_file, - _shared_maplimits, -) -from eegprep.functions.popfunc.pop_loadset import pop_loadset -from eegprep.functions.popfunc.pop_epoch import pop_epoch -from eegprep.functions.popfunc._plot_utils import component_activations -from eegprep.functions.popfunc.pop_plotdata import pop_plotdata -from eegprep.functions.popfunc.pop_plottopo import pop_plottopo, pop_plottopo_dialog_spec -from eegprep.functions.popfunc.pop_prop import pop_prop, pop_prop_dialog_spec -from eegprep.functions.popfunc.pop_signalstat import pop_signalstat -from eegprep.functions.popfunc.pop_spectopo import pop_spectopo -from eegprep.functions.popfunc.pop_timtopo import pop_timtopo -from eegprep.functions.popfunc.pop_topoplot import pop_topoplot -from eegprep.functions.studyfunc.pop_chanplot import pop_chanplot, pop_chanplot_dialog_spec -from eegprep.functions.sigprocfunc.coregister import ( - ElectrodeSet, - apply_coregistration_transform, - coregister, - estimate_coregistration_transform, - load_coregistration_electrodes, - match_electrodes, - read_electrode_file, - traditional_transform_matrix, -) -from eegprep.plugins.ICLabel.pop_viewprops import pop_viewprops -from eegprep.functions.sigprocfunc.headplot import ( - MAPLIMIT_PADDING, - default_headplot_mesh_transform, - headplot, - headplot_setup, - load_headplot_spline, - packaged_headplot_path, - _interpolate_values, -) -from tests.fixtures import SAMPLE_DATASET_PATH, create_test_eeg_with_ica - - -@pytest.fixture(scope="module") -def sample_eeg(): - return pop_loadset(SAMPLE_DATASET_PATH) - - -@pytest.fixture(scope="module") -def sample_epoch(sample_eeg): - epoched, _command = pop_epoch(deepcopy(sample_eeg), ["square"], [-0.1, 0.2], return_com=True) - return epoched - - -@pytest.fixture -def ica_epoch(): - return create_test_eeg_with_ica(n_channels=6, n_samples=40, n_trials=4, n_components=4) - - -def test_pop_spectopo_plots_sample_data_headlessly(sample_eeg): - result, command = pop_spectopo(sample_eeg, dataflag=1, freqs=[6, 10], return_com=True) - - assert result["spectra"].shape[0] == sample_eeg["nbchan"] - assert result["freqs"].ndim == 1 - assert np.isfinite(result["spectra"]).all() - assert result["figure"] is not None - assert "pop_spectopo(EEG" in command - _assert_python_command(command) - plt.close(result["figure"]) - - -def test_pop_spectopo_component_default_controls_succeed(ica_epoch): - result, command = pop_spectopo( - ica_epoch, dataflag=0, freqs=[10], plotchan=0, icamode=True, icacomps=[1, 2], nicamaps=2, return_com=True - ) - - assert result["spectra"].shape[0] == 2 - assert "pop_spectopo(EEG" in command - plt.close(result["figure"]) - - -def test_pop_spectopo_rejects_nondefault_plotchan(ica_epoch): - with pytest.raises(ValueError, match="whole-scalp component spectra"): - pop_spectopo(ica_epoch, dataflag=0, freqs=[10], plotchan=3, icacomps=[1, 2]) - - -def test_pop_spectopo_rejects_max_power_plotchan(ica_epoch): - with pytest.raises(ValueError, match="whole-scalp component spectra"): - pop_spectopo(ica_epoch, dataflag=0, freqs=[10], plotchan=[], icacomps=[1, 2]) - - -def test_pop_spectopo_rejects_datacomp_icamode(ica_epoch): - with pytest.raises(ValueError, match="component spectra"): - pop_spectopo(ica_epoch, dataflag=0, freqs=[10], icamode=False, icacomps=[1, 2]) - - -def test_pop_prop_plots_sample_channel_properties(sample_eeg): - figure, command = pop_prop(sample_eeg, typecomp=1, chanorcomp=1, return_com=True) - - assert len(figure.axes) >= 3 - assert "pop_prop(EEG" in command - _assert_python_command(command) - plt.close(figure) - - -def test_pop_headplot_plots_sample_latency_map_with_spline_setup(sample_eeg, tmp_path): - eeg = deepcopy(sample_eeg) - splinefile = tmp_path / "sample.spl" - setup = {"splinefile": str(splinefile), "transform": [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100]} - - figures, command = pop_headplot(eeg, typeplot=1, items=[0], setup=setup, return_com=True) - - assert len(figures) == 1 - assert figures[0].axes[0].name == "3d" - assert splinefile.exists() - assert _current_spline_file(eeg, 1) != str(splinefile) # plot must not mutate the caller's EEG - assert "setup={" in command - _assert_python_command(command) - plt.close(figures[0]) - - -def test_pop_headplot_does_not_mutate_caller_eeg(sample_eeg, tmp_path): - eeg = deepcopy(sample_eeg) - before = deepcopy(eeg) - setup = { - "splinefile": str(tmp_path / "nomutate.spl"), - "transform": [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100], - } - - figures = pop_headplot(eeg, typeplot=1, items=[0], setup=setup) - - assert set(eeg.keys()) == set(before.keys()) # no new spline/mesh keys added - assert np.array_equal(np.asarray(eeg["splinefile"]), np.asarray(before["splinefile"])) - assert "headplotmeshfile" not in eeg - assert np.array_equal(np.asarray(eeg["data"]), np.asarray(before["data"])) - for fig in figures: - plt.close(fig) - - -def test_pop_headplot_single_map_has_eeglab_like_title_and_surface(sample_eeg, tmp_path): - eeg = deepcopy(sample_eeg) - title = "ERP scalp maps of dataset: eeglab_data" - splinefile = tmp_path / "single_map.spl" - - figures, _command = pop_headplot( - eeg, - typeplot=1, - items=[0], - topotitle=title, - setup={"splinefile": str(splinefile), "transform": [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100]}, - colorbar="off", - return_com=True, - ) - - figure = figures[0] - assert figure._suptitle is not None - assert figure._suptitle.get_text() == title - assert figure.axes[0].get_title() == "" - width, height = figure.get_size_inches() - assert width >= 8.0 - assert height >= 6.0 - assert figure.axes[0].get_position().width > 0.6 - facecolors = np.asarray(figure.axes[0].collections[0].get_facecolors()) - assert np.isfinite(facecolors).all() - assert np.all(facecolors[:, 3] == 1) - plt.close(figure) - - -def test_headplot_setup_file_can_be_reused_for_sample_data(sample_eeg, tmp_path): - splinefile = tmp_path / "reuse.spl" - transform = [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100] - - created = headplot_setup(sample_eeg["chanlocs"], splinefile, chaninfo=sample_eeg["chaninfo"], transform=transform) - spline = load_headplot_spline(created) - figure = headplot(np.nanmean(np.asarray(sample_eeg["data"], dtype=float), axis=1), created, title="Sample") - - assert spline.g.shape == (sample_eeg["nbchan"], sample_eeg["nbchan"]) - assert spline.gx.shape[1] == sample_eeg["nbchan"] - assert len(figure.axes) >= 1 - plt.close(figure) - - -def test_headplot_setup_plotmeshonly_and_orilocs_options(sample_eeg, tmp_path): - transform = [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100] - preview_file = tmp_path / "preview.spl" - - preview = headplot_setup( - sample_eeg["chanlocs"], - preview_file, - chaninfo=sample_eeg["chaninfo"], - transform=transform, - plotmeshonly="sphere", - ) - - assert preview.axes[0].name == "3d" - assert not preview_file.exists() - plt.close(preview) - - created = headplot_setup( - sample_eeg["chanlocs"], - tmp_path / "orilocs.spl", - chaninfo=sample_eeg["chaninfo"], - transform=transform, - orilocs="on", - ) - spline = load_headplot_spline(created) - np.testing.assert_allclose(spline.new_electrodes, np.column_stack([spline.xe, spline.ye, spline.ze])) - - -def test_pop_headplot_setup_reuses_existing_spline_file(sample_eeg, tmp_path): - eeg = deepcopy(sample_eeg) - splinefile = tmp_path / "existing.spl" - original_transform = [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100] - replay_transform = [1, 2, 3, 0.1, 0.2, 0.3, 4, 5, 6] - headplot_setup(sample_eeg["chanlocs"], splinefile, chaninfo=sample_eeg["chaninfo"], transform=original_transform) - - figures, command = pop_headplot( - eeg, - typeplot=1, - items=[0], - setup={"splinefile": str(splinefile), "transform": replay_transform}, - return_com=True, - ) - - np.testing.assert_allclose(load_headplot_spline(splinefile).transform, original_transform) - assert "setup={" in command - plt.close(figures[0]) - - -def test_headplot_setup_falls_back_when_xyz_fields_are_empty_arrays(sample_eeg, tmp_path): - eeg = deepcopy(sample_eeg) - eeg["chanlocs"][0]["X"] = np.asarray([]) - eeg["chanlocs"][0]["Y"] = np.asarray([]) - eeg["chanlocs"][0]["Z"] = np.asarray([]) - - created = headplot_setup( - eeg["chanlocs"], - tmp_path / "empty_xyz.spl", - chaninfo=eeg["chaninfo"], - transform=[0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100], - ) - - assert created.exists() - - -def test_coregister_reads_packaged_reference_and_matches_sample_labels(sample_eeg): - source = load_coregistration_electrodes(sample_eeg["chanlocs"], chaninfo=sample_eeg["chaninfo"]) - target = read_electrode_file(packaged_headplot_path("mheadnew.xyz")) - - source_indices, target_indices, labels = match_electrodes(source, target) - transform = estimate_coregistration_transform(source, target, method="traditional") - transformed = apply_coregistration_transform(source.points[source_indices], transform) - - assert len(labels) >= 20 - assert source_indices.shape == target_indices.shape - assert transform.shape == (9,) - assert np.isfinite(transformed).all() - - -def test_coregister_sample_warp_reports_positive_resize_fields(sample_eeg): - source = load_coregistration_electrodes(sample_eeg["chanlocs"], chaninfo=sample_eeg["chaninfo"]) - target = read_electrode_file(packaged_headplot_path("mheadnew.xyz")) - - transform = estimate_coregistration_transform(source, target, method="traditional") - - assert np.all(transform[6:9] > 0) - assert np.all(np.abs(transform[3:6]) <= np.pi) - - -def test_coregister_fits_traditional_and_shared_scale_transforms(): - source = ElectrodeSet( - ["Nz", "LPA", "RPA", "Cz", "Pz"], - np.asarray( - [ - [0.0, 1.0, 0.0], - [-1.0, 0.0, -0.2], - [1.0, 0.0, -0.2], - [0.0, 0.0, 1.0], - [0.0, -1.0, 0.0], - ], - dtype=float, - ), - ) - traditional = np.asarray([5.0, -3.0, 2.0, 0.05, -0.04, 0.03, 2.0, 1.5, 1.2]) - target = ElectrodeSet(source.labels, apply_coregistration_transform(source.points, traditional)) - - fitted = estimate_coregistration_transform(source, target, method="traditional") - shared = estimate_coregistration_transform( - source, - ElectrodeSet( - source.labels, apply_coregistration_transform(source.points, [1, 2, 3, 0.02, 0.01, -0.02, 2, 2, 2]) - ), - method="globalrescale", - ) - - np.testing.assert_allclose(apply_coregistration_transform(source.points, fitted), target.points, atol=1e-6) - np.testing.assert_allclose(shared[6], shared[7]) - np.testing.assert_allclose(shared[7], shared[8]) - - with pytest.raises(ValueError, match="traditional.*globalrescale"): - estimate_coregistration_transform(source, target, method="nonlin") - - -def test_coregister_noninteractive_path_returns_transformed_electrodes(sample_eeg): - result = coregister( - sample_eeg["chanlocs"], - packaged_headplot_path("mheadnew.xyz"), - chaninfo1=sample_eeg["chaninfo"], - warp="auto", - ) - - assert result.electrodes.points.shape[1] == 3 - assert result.transform.shape == (9,) - assert np.isfinite(result.electrodes.points).all() - - -def test_pop_headplot_component_path_creates_ica_spline(ica_epoch, tmp_path): - eeg = deepcopy(ica_epoch) - for index, chanloc in enumerate(eeg["chanlocs"]): - chanloc.setdefault("X", float(index - 2)) - chanloc.setdefault("Y", float(index % 3)) - chanloc.setdefault("Z", float(30 + index)) - splinefile = tmp_path / "component.spl" - - figures, command = pop_headplot( - eeg, - typeplot=0, - items=[1, -2], - setup={"splinefile": str(splinefile), "transform": [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100]}, - return_com=True, - ) - - assert len(figures) == 1 - assert _current_spline_file(eeg, 0) != str(splinefile) # plot must not mutate the caller's EEG - assert "IC 2" in figures[0].axes[1].get_title() - _assert_python_command(command) - plt.close(figures[0]) - - -def test_pop_headplot_component_path_skips_sample_channels_without_locations(tmp_path): - eeg = pop_loadset(SAMPLE_DATASET_PATH.parent / "eeglab_data_with_ica_tmp.set") - splinefile = tmp_path / "sample_component.spl" - - figures, command = pop_headplot( - eeg, - typeplot=0, - items=[1], - setup={"splinefile": str(splinefile), "transform": [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100]}, - return_com=True, - ) - spline = load_headplot_spline(splinefile) - - assert len(figures) == 1 - assert spline.xe.size < eeg["nbchan"] - assert np.max(spline.indices) < np.asarray(eeg["icawinv"]).shape[0] - _assert_python_command(command) - plt.close(figures[0]) - - -def test_pop_headplot_shared_string_maplimits_are_replayable(ica_epoch, tmp_path): - eeg = deepcopy(ica_epoch) - for index, chanloc in enumerate(eeg["chanlocs"]): - chanloc.setdefault("X", float(index - 2)) - chanloc.setdefault("Y", float(index % 3)) - chanloc.setdefault("Z", float(30 + index)) - splinefile = tmp_path / "maplimits.spl" - - figures, command = pop_headplot( - eeg, - typeplot=0, - items=[1, 2], - setup={"splinefile": str(splinefile), "transform": [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100]}, - maplimits="absmax", - return_com=True, - ) - - assert len(figures) == 1 - assert "maplimits='absmax'" in command - _assert_python_command(command) - plt.close(figures[0]) - - -def test_headplot_maplimits_use_eeglab_padding(ica_epoch): - values = np.asarray(ica_epoch["icawinv"], dtype=float)[:, 0] - expected = float(np.nanmax(np.abs(values)) * MAPLIMIT_PADDING) - np.testing.assert_allclose(_shared_maplimits([values], "absmax"), [-expected, expected]) - - -def test_pop_headplot_gui_setup_result_is_replayable(sample_eeg, tmp_path): - class Renderer: - def __init__(self): - self.spec = None - - def run(self, spec, initial_values=None): - self.spec = spec - return { - "loadcb": False, - "compcb": True, - "setup_file": str(tmp_path / "gui.spl"), - "meshfile": 1, - "meshchanfile": 1, - "transform": "0 -10 0 -0.1 0 -1.6 1100 1100 1100", - "items": "0", - "topotitle": "GUI setup", - "rowcols": "", - "options": "'electrodes', 'off'", - } - - eeg = deepcopy(sample_eeg) - renderer = Renderer() - figures, command = pop_headplot(eeg, typeplot=1, gui=True, renderer=renderer, return_com=True) - - assert renderer.spec.title == "ERP head plot(s) -- pop_headplot()" - assert (tmp_path / "gui.spl").exists() - assert _current_spline_file(eeg, 1) == "" # plot must not mutate the caller's EEG - assert "setup={" in command - assert "electrodes='off'" in command - _assert_python_command(command) - replay_namespace = {"EEG": deepcopy(sample_eeg), "pop_headplot": pop_headplot} - exec(command, replay_namespace) - plt.close(figures[0]) - plt.close("all") - - -@pytest.mark.matlab -def test_headplot_setup_spline_metadata_matches_eeglab(sample_eeg, tmp_path): - if os.environ.get("EEGPREP_SKIP_MATLAB") == "1": - pytest.skip("MATLAB tests disabled via EEGPREP_SKIP_MATLAB") - try: - matlab_engine = importlib.import_module("matlab.engine") - except ImportError as exc: - pytest.skip(f"MATLAB not available: {exc}") - eeglab_root = _eeglab_reference_root() - if not eeglab_root.exists(): - pytest.skip("EEGLAB reference checkout not available") - - transform = [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100] - python_spline = tmp_path / "python.spl" - matlab_spline = tmp_path / "matlab.spl" - matlab_output = tmp_path / "headplot_setup.mat" - headplot_setup(sample_eeg["chanlocs"], python_spline, chaninfo=sample_eeg["chaninfo"], transform=transform) - - engine = matlab_engine.start_matlab() - try: - for relative in ( - "functions/guifunc", - "functions/popfunc", - "functions/adminfunc", - "functions/sigprocfunc", - "functions/miscfunc", - "functions/supportfiles", - "plugins/dipfit", - ): - engine.addpath(str(eeglab_root / relative), nargout=0) - engine.eval( - f""" - EEG = pop_loadset('{_matlab_string(SAMPLE_DATASET_PATH)}'); - headplot('setup', EEG.chanlocs, '{_matlab_string(matlab_spline)}', ... - 'chaninfo', EEG.chaninfo, 'meshfile', 'mheadnew.mat', ... - 'transform', [{_matlab_vector(transform)}]); - S = load('{_matlab_string(matlab_spline)}', '-mat'); - G = S.G; gx = S.gx; Xe = S.Xe; Ye = S.Ye; Ze = S.Ze; newElect = S.newElect; - indices = S.indices; transform = S.transform; - save('{_matlab_string(matlab_output)}', 'G', 'gx', 'Xe', 'Ye', 'Ze', 'newElect', 'indices', 'transform'); - """, - nargout=0, - ) - finally: - engine.quit() - - py_spline = load_headplot_spline(python_spline) - ml = scipy.io.loadmat(matlab_output, squeeze_me=True) - np.testing.assert_allclose(py_spline.g, ml["G"], rtol=1e-10, atol=1e-10) - np.testing.assert_allclose(py_spline.gx, ml["gx"], rtol=1e-10, atol=1e-10) - np.testing.assert_allclose(py_spline.xe, ml["Xe"], rtol=1e-10, atol=1e-10) - np.testing.assert_allclose(py_spline.ye, ml["Ye"], rtol=1e-10, atol=1e-10) - np.testing.assert_allclose(py_spline.ze, ml["Ze"], rtol=1e-10, atol=1e-10) - np.testing.assert_allclose(py_spline.new_electrodes, ml["newElect"], rtol=1e-10, atol=1e-10) - np.testing.assert_array_equal(py_spline.indices + 1, np.asarray(ml["indices"], dtype=int).ravel()) - np.testing.assert_allclose(py_spline.transform, np.asarray(ml["transform"], dtype=float).ravel()) - - -@pytest.mark.matlab -def test_headplot_interpolated_values_match_eeglab(sample_eeg, tmp_path): - if os.environ.get("EEGPREP_SKIP_MATLAB") == "1": - pytest.skip("MATLAB tests disabled via EEGPREP_SKIP_MATLAB") - try: - matlab_engine = importlib.import_module("matlab.engine") - except ImportError as exc: - pytest.skip(f"MATLAB not available: {exc}") - eeglab_root = _eeglab_reference_root() - if not eeglab_root.exists(): - pytest.skip("EEGLAB reference checkout not available") - - transform = [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100] - values = np.linspace(-4.0, 6.0, int(sample_eeg["nbchan"])) - python_spline = tmp_path / "python_interp.spl" - matlab_spline = tmp_path / "matlab_interp.spl" - matlab_output = tmp_path / "headplot_interp.mat" - headplot_setup(sample_eeg["chanlocs"], python_spline, chaninfo=sample_eeg["chaninfo"], transform=transform) - - engine = matlab_engine.start_matlab() - try: - for relative in ( - "functions/guifunc", - "functions/popfunc", - "functions/adminfunc", - "functions/sigprocfunc", - "functions/miscfunc", - "functions/supportfiles", - "plugins/dipfit", - ): - engine.addpath(str(eeglab_root / relative), nargout=0) - engine.eval( - f""" - EEG = pop_loadset('{_matlab_string(SAMPLE_DATASET_PATH)}'); - headplot('setup', EEG.chanlocs, '{_matlab_string(matlab_spline)}', ... - 'chaninfo', EEG.chaninfo, 'meshfile', 'mheadnew.mat', ... - 'transform', [{_matlab_vector(transform)}]); - S = load('{_matlab_string(matlab_spline)}', '-mat'); - values = [{_matlab_vector(values)}]'; - meanval = mean(values); - centered = values - meanval; - enum = length(values); - lamd = 0.1; - C = pinv([(S.G + lamd); ones(1, enum)]) * [centered(:); 0]; - P = S.gx * C + meanval; - save('{_matlab_string(matlab_output)}', 'P'); - """, - nargout=0, - ) - finally: - engine.quit() - - py_values = _interpolate_values(values, load_headplot_spline(python_spline)) - ml = scipy.io.loadmat(matlab_output, squeeze_me=True) - np.testing.assert_allclose(py_values, np.asarray(ml["P"], dtype=float).ravel(), rtol=1e-10, atol=1e-10) - - -@pytest.mark.matlab -def test_headplot_setup_ica_metadata_matches_eeglab(tmp_path): - if os.environ.get("EEGPREP_SKIP_MATLAB") == "1": - pytest.skip("MATLAB tests disabled via EEGPREP_SKIP_MATLAB") - try: - matlab_engine = importlib.import_module("matlab.engine") - except ImportError as exc: - pytest.skip(f"MATLAB not available: {exc}") - eeglab_root = _eeglab_reference_root() - if not eeglab_root.exists(): - pytest.skip("EEGLAB reference checkout not available") - - sample_ica = SAMPLE_DATASET_PATH.parent / "eeglab_data_with_ica_tmp.set" - transform = [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100] - eeg = pop_loadset(sample_ica) - python_spline = tmp_path / "python_ica.spl" - matlab_spline = tmp_path / "matlab_ica.spl" - matlab_output = tmp_path / "headplot_ica.mat" - headplot_setup(eeg["chanlocs"], python_spline, chaninfo=eeg["chaninfo"], transform=transform, ica="on") - - engine = matlab_engine.start_matlab() - try: - for relative in ( - "functions/guifunc", - "functions/popfunc", - "functions/adminfunc", - "functions/sigprocfunc", - "functions/miscfunc", - "functions/supportfiles", - "plugins/dipfit", - ): - engine.addpath(str(eeglab_root / relative), nargout=0) - engine.eval( - f""" - EEG = pop_loadset('{_matlab_string(sample_ica)}'); - headplot('setup', EEG.chanlocs, '{_matlab_string(matlab_spline)}', ... - 'chaninfo', EEG.chaninfo, 'ica', 'on', 'meshfile', 'mheadnew.mat', ... - 'transform', [{_matlab_vector(transform)}]); - S = load('{_matlab_string(matlab_spline)}', '-mat'); - G = S.G; gx = S.gx; Xe = S.Xe; indices = S.indices; transform = S.transform; - save('{_matlab_string(matlab_output)}', 'G', 'gx', 'Xe', 'indices', 'transform'); - """, - nargout=0, - ) - finally: - engine.quit() - - py_spline = load_headplot_spline(python_spline) - ml = scipy.io.loadmat(matlab_output, squeeze_me=True) - assert py_spline.g.shape == np.asarray(ml["G"]).shape - assert py_spline.gx.shape == np.asarray(ml["gx"]).shape - np.testing.assert_allclose(py_spline.xe, np.asarray(ml["Xe"], dtype=float).ravel(), rtol=1e-10, atol=1e-10) - np.testing.assert_array_equal(py_spline.indices + 1, np.asarray(ml["indices"], dtype=int).ravel()) - np.testing.assert_allclose(py_spline.transform, np.asarray(ml["transform"], dtype=float).ravel()) - - -@pytest.mark.matlab -def test_traditional_transform_matrix_matches_eeglab(tmp_path): - if os.environ.get("EEGPREP_SKIP_MATLAB") == "1": - pytest.skip("MATLAB tests disabled via EEGPREP_SKIP_MATLAB") - try: - matlab_engine = importlib.import_module("matlab.engine") - except ImportError as exc: - pytest.skip(f"MATLAB not available: {exc}") - eeglab_root = _eeglab_reference_root() - if not eeglab_root.exists(): - pytest.skip("EEGLAB reference checkout not available") - - transform = [5, -3, 2, 0.05, -0.04, 0.03, 2, 1.5, 1.2] - matlab_output = tmp_path / "traditionaldipfit.mat" - engine = matlab_engine.start_matlab() - try: - engine.addpath(str(eeglab_root / "plugins" / "dipfit"), nargout=0) - engine.eval( - f""" - H = traditionaldipfit([{_matlab_vector(transform)}]); - save('{_matlab_string(matlab_output)}', 'H'); - """, - nargout=0, - ) - finally: - engine.quit() - - ml = scipy.io.loadmat(matlab_output, squeeze_me=True) - np.testing.assert_allclose(traditional_transform_matrix(transform), ml["H"], rtol=1e-12, atol=1e-12) - - -def test_channel_erp_plot_wrappers_work_on_sample_epochs(sample_epoch): - timtopo_fig, timtopo_command = pop_timtopo(sample_epoch, plottimes=[0], return_com=True) - plottopo_fig, plottopo_command = pop_plottopo(sample_epoch, chans=[1, 2], return_com=True) - erpimage_result, erpimage_command = pop_erpimage(sample_epoch, typeplot=1, index=1, return_com=True) - - assert len(timtopo_fig.axes) >= 2 - assert len(plottopo_fig.axes) >= 2 - assert erpimage_result["image"].shape[0] == sample_epoch["trials"] - for command in (timtopo_command, plottopo_command, erpimage_command): - _assert_python_command(command) - plt.close(timtopo_fig) - plt.close(plottopo_fig) - plt.close(erpimage_result["figure"]) - - -def test_component_plot_wrappers_work_when_ica_fields_exist(ica_epoch): - spectopo_result, spectopo_command = pop_spectopo(ica_epoch, dataflag=0, freqs=[10], return_com=True) - plotdata_fig, plotdata_command = pop_plotdata(ica_epoch, components=[1, 2], return_com=True) - envtopo_fig, envtopo_command = pop_envtopo(ica_epoch, components=[1, 2], return_com=True) - erpimage_result, erpimage_command = pop_erpimage(ica_epoch, typeplot=0, index=1, return_com=True) - - assert spectopo_result["spectra"].shape[0] == 4 - assert len(plotdata_fig.axes) >= 2 - assert len(envtopo_fig.axes) >= 2 - assert erpimage_result["image"].shape[0] == 4 - for command in (spectopo_command, plotdata_command, envtopo_command, erpimage_command): - _assert_python_command(command) - plt.close(spectopo_result["figure"]) - plt.close(plotdata_fig) - plt.close(envtopo_fig) - plt.close(erpimage_result["figure"]) - - -def test_pop_prop_attaches_component_activity_browser_model(ica_epoch): - figure, command = pop_prop(ica_epoch, typecomp=0, chanorcomp=1, return_com=True) - - activity = figure.eegprep_activity_view - - assert activity.data.mode == "epoched" - assert activity.data.n_channels == 1 - assert activity.state.title == "Scrolling IC1 Activity -- eegplot()" - assert "pop_prop(EEG" in command - plt.close(figure) - - -def test_pop_viewprops_attaches_channel_activity_browser_models(ica_epoch): - eeg = deepcopy(ica_epoch) - eeg["event"] = [{"type": "stim", "latency": 2}] - - figures, command = pop_viewprops(eeg, typecomp=1, chanorcomp=[1], scroll_event=0, return_com=True) - activity = figures[0].eegprep_activity_views[0] - - assert activity.data.mode == "epoched" - assert activity.data.channel_labels == ("Ch1",) - assert activity.state.events == [] - assert "pop_viewprops(EEG" in command - plt.close(figures[0]) - - -def test_pop_erpimage_projects_components_to_selected_channel(ica_epoch): - result, command = pop_erpimage(ica_epoch, typeplot=0, index=1, projchan=[2], return_com=True) - - expected = component_activations(ica_epoch)[0].T * np.asarray(ica_epoch["icawinv"], dtype=float)[1, 0] - np.testing.assert_allclose(result["image"], expected) - assert "projchan=[2]" in command - _assert_python_command(command) - plt.close(result["figure"]) - - -def test_phase4_dialog_specs_match_eeglab_selector_layouts(sample_eeg, ica_epoch): - prop_controls = controls_by_tag(pop_prop_dialog_spec(sample_eeg, typecomp=1)) - assert prop_controls["chanorcomp_button"].callback is not None - assert prop_controls["chanorcomp_button"].callback.params["return_indices"] is True - - erpimage_spec = pop_erpimage_dialog_spec(ica_epoch, typeplot=0) - assert erpimage_spec.size == (1113, 831) - assert erpimage_spec.row_spacing == 4 - assert erpimage_spec.geometry[0] == (1, 1, 0.1, 0.8, 2.1) - assert erpimage_spec.geometry[1] == (1, 1, 0.4, 0.5, 2.1) - assert [control.tag for control in erpimage_spec.controls[:10]] == [ - None, - "index", - None, - None, - None, - None, - "projchan", - "projchan_button", - None, - "title", - ] - - comperp_spec = pop_comperp_dialog_spec([sample_eeg], flag=1) - assert comperp_spec.row_spacing == 8 - assert [control.string for control in comperp_spec.controls[2:5]] == ["avg.", "std.", "all ERPs"] - assert comperp_spec.geometry[0] == comperp_spec.geometry[1] - assert comperp_spec.geometry[-1] == (1.48, 1.03, 1) - - plottopo_controls = controls_by_tag(pop_plottopo_dialog_spec(sample_eeg)) - assert plottopo_controls["rect"].value is False - assert plottopo_controls["options"].value == "'ydir', -1" - - chanplot_controls = controls_by_tag(pop_chanplot_dialog_spec({"name": "demo study"}, [sample_eeg])) - assert chanplot_controls["chan_list"].string.startswith("All channels|") - assert chanplot_controls["chan_onechan"].string == "All subjects" - assert chanplot_controls["plot_chan_erp"].callback is not None - assert chanplot_controls["plot_chan_erp"].callback.params["value"] == "erp" - assert chanplot_controls["plot_chan_erpimage"].enabled is False - - headplot_controls = controls_by_tag(pop_headplot_dialog_spec(sample_eeg, typeplot=1)) - headplot_spec = pop_headplot_dialog_spec(sample_eeg, typeplot=1) - assert headplot_spec.size == (1290, 547) - assert headplot_spec.content_margins == (42, 35, 42, 13) - assert headplot_controls["loadcb"].callback.name == "headplot_setup_mode" - assert headplot_controls["compcb"].callback.name == "headplot_setup_mode" - assert headplot_controls["setup_browse"].callback.name == "select_file" - np.testing.assert_allclose( - default_headplot_mesh_transform("mheadnew.mat"), [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100] - ) - assert headplot_controls["meshfile"].callback.params["transform_choices"] == ( - "0 -10 0 -0.1 0 -1.6 1100 1100 1100", - "0 -15 -15 0.05 0 -1.57 100 88 110", - ) - assert headplot_controls["mesh_browse"].callback.params["custom_transform"] == "0 -10 0 -0.1 0 -1.6 1100 1100 1100" - assert headplot_controls["manual_coreg"].callback.name == "headplot_manual_coreg" - assert headplot_controls["manual_coreg"].callback.params["mesh_source"] == "meshfile" - assert headplot_controls["transform"].enabled is True - - -def test_headplot_setup_mode_checkboxes_behave_like_radio_buttons(): - class Toggle: - def __init__(self, checked): - self.checked = checked - self.signals_blocked = False - - def isChecked(self): - return self.checked - - def setChecked(self, value): - self.checked = bool(value) - - def blockSignals(self, value): - self.signals_blocked = bool(value) - - def setEnabled(self, _enabled): - pass - - load = Toggle(False) - comp = Toggle(True) - widgets = { - "loadcb": load, - "compcb": comp, - "load": Toggle(False), - "setup_file": Toggle(True), - } - - QtDialogRenderer._set_headplot_setup_mode( - widgets, - { - "source": "compcb", - "mode": "compute", - "peer": "loadcb", - "load_targets": ("load",), - "setup_targets": ("setup_file",), - }, - False, - ) - - assert comp.isChecked() is True - assert load.isChecked() is False - - -def test_headplot_mesh_choice_restores_each_template_transform(sample_eeg): - class Text: - def __init__(self): - self.value = "" - - def setText(self, value): - self.value = value - - class Combo: - def __init__(self): - self.index = 0 - - def setCurrentIndex(self, index): - self.index = index - - def count(self): - return 2 - - spec = pop_headplot_dialog_spec(sample_eeg, typeplot=1) - params = controls_by_tag(spec)["meshfile"].callback.params - transform = Text() - reference = Combo() - - QtDialogRenderer._set_headplot_mesh_choice( - {"meshchanfile": reference, "transform": transform}, - params, - 1, - ) - assert transform.value == "0 -15 -15 0.05 0 -1.57 100 88 110" - assert reference.index == 1 - - QtDialogRenderer._set_headplot_mesh_choice( - {"meshchanfile": reference, "transform": transform}, - params, - 0, - ) - assert transform.value == "0 -10 0 -0.1 0 -1.6 1100 1100 1100" - assert reference.index == 0 - - -def test_headplot_mesh_browse_resets_to_generic_transform(monkeypatch, sample_eeg, tmp_path): - class FileDialog: - @staticmethod - def getOpenFileName(*_args): - return str(tmp_path / "custom.mat"), "" - - class QtWidgets: - QFileDialog = FileDialog - - class Combo: - def __init__(self): - self.value = None - self.text = "" - - def setEditable(self, _editable): - pass - - def setEditText(self, value): - self.text = value - - def setProperty(self, _name, value): - self.value = value - - class Text: - def __init__(self): - self.value = "" - - def setText(self, value): - self.value = value - - monkeypatch.setattr("eegprep.functions.guifunc.qt._require_qt", lambda: (None, QtWidgets)) - params = controls_by_tag(pop_headplot_dialog_spec(sample_eeg, typeplot=1))["mesh_browse"].callback.params - target = Combo() - transform = Text() - - QtDialogRenderer._select_file(object(), target, params, {"transform": transform}) - - assert target.value == str(tmp_path / "custom.mat") - assert transform.value == "0 -10 0 -0.1 0 -1.6 1100 1100 1100" - - -def test_headplot_manual_coreg_callback_writes_transform(monkeypatch, sample_eeg): - class Text: - def __init__(self, value=""): - self.value = value - - def text(self): - return self.value - - def setText(self, value): - self.value = value - - class Combo: - def __init__(self, index=0): - self.index = index - - def currentIndex(self): - return self.index - - def property(self, _name): - return None - - calls = [] - - def fake_coregister(*args, **kwargs): - calls.append((args, kwargs)) - return [1, 2, 3, 0.1, 0.2, 0.3, 4, 5, 6] - - monkeypatch.setattr("eegprep.functions.guifunc.coregister.run_coregister_dialog", fake_coregister) - params = controls_by_tag(pop_headplot_dialog_spec(sample_eeg, typeplot=1))["manual_coreg"].callback.params - transform = Text("0 -10 0 -0.1 0 -1.6 1100 1100 1100") - - QtDialogRenderer._run_headplot_manual_coreg( - object(), - {"meshfile": Combo(0), "meshchanfile": Combo(0), "transform": transform}, - params, - ) - - assert transform.value == "1 2 3 0.1 0.2 0.3 4 5 6" - assert calls[0][1]["meshfile"] == "mheadnew.mat" - assert calls[0][1]["transform"] == "0 -10 0 -0.1 0 -1.6 1100 1100 1100" - - -def test_pop_plottopo_rect_option_switches_layout(sample_epoch): - eeg = deepcopy(sample_epoch) - eeg["data"] = eeg["data"][:4] - eeg["nbchan"] = 4 - eeg["chanlocs"] = [ - {"labels": "Fz", "theta": 0, "radius": 0.5}, - {"labels": "C4", "theta": 90, "radius": 0.5}, - {"labels": "Pz", "theta": 180, "radius": 0.5}, - {"labels": "C3", "theta": -90, "radius": 0.5}, - ] - - topo_fig, topo_command = pop_plottopo(eeg, chans=[1, 2, 3, 4], return_com=True) - rect_fig, rect_command = pop_plottopo(eeg, chans=[1, 2, 3, 4], rect=True, return_com=True) - - topo_positions = [axis.get_position().bounds for axis in topo_fig.axes] - rect_positions = [axis.get_position().bounds for axis in rect_fig.axes] - assert topo_positions != rect_positions - assert "rect=True" in rect_command - _assert_python_command(topo_command) - _assert_python_command(rect_command) - plt.close(topo_fig) - plt.close(rect_fig) - - -def test_component_activations_use_icachansind_subset(ica_epoch): - eeg = deepcopy(ica_epoch) - eeg["icaact"] = None - eeg["icachansind"] = np.array([1, 3]) - eeg["icaweights"] = np.eye(2) - eeg["icasphere"] = np.eye(2) - eeg["icawinv"] = np.eye(2) - - activations = component_activations(eeg) - - np.testing.assert_allclose(activations[0], eeg["data"][1]) - np.testing.assert_allclose(activations[1], eeg["data"][3]) - - -def test_component_map_plots_use_icachansind_subset(ica_epoch): - eeg = deepcopy(ica_epoch) - eeg["icaact"] = None - eeg["icachansind"] = np.array([1, 3]) - eeg["icaweights"] = np.eye(2) - eeg["icasphere"] = np.eye(2) - eeg["icawinv"] = np.eye(2) - - figures, topoplot_command = pop_topoplot(eeg, typeplot=0, items=[1], colorbar="off", return_com=True) - prop_figure, prop_command = pop_prop(eeg, typecomp=0, chanorcomp=1, return_com=True) - stat_result, stat_command = pop_signalstat(eeg, typeproc=0, cnum=1, return_com=True) - - assert len(figures) == 1 - assert prop_figure is not None - assert stat_result.figure is not None - _assert_python_command(topoplot_command) - _assert_python_command(prop_command) - _assert_python_command(stat_command) - plt.close(figures[0]) - plt.close(prop_figure) - plt.close(stat_result.figure) - - -def test_pop_envtopo_uses_icachansind_subset_and_rejects_multiple(ica_epoch): - eeg = deepcopy(ica_epoch) - eeg["icaact"] = None - eeg["icachansind"] = np.array([1, 3]) - eeg["icaweights"] = np.eye(2) - eeg["icasphere"] = np.eye(2) - eeg["icawinv"] = np.eye(2) - - figure, command = pop_envtopo(eeg, components=[1], return_com=True) - - assert len(figure.axes) >= 2 - _assert_python_command(command) - plt.close(figure) - with pytest.raises(ValueError, match="one dataset"): - pop_envtopo([ica_epoch, deepcopy(ica_epoch)], components=[1]) - - -def test_pop_comperp_and_chanplot_work_on_epoched_dataset_lists(sample_epoch): - second = deepcopy(sample_epoch) - second["setname"] = "second" - - comperp_result, comperp_command = pop_comperp([sample_epoch, second], flag=1, datadd=[1, 2], return_com=True) - study, chanplot_command, chanplot_fig = pop_chanplot( - {"name": "demo study"}, [sample_epoch, second], channels=[1], return_com=True - ) - - assert comperp_result["erp1"].shape[1] == sample_epoch["pnts"] - assert study["etc"]["last_chanplot"]["channels"] == [1] - _assert_python_command(comperp_command) - _assert_python_command(chanplot_command) - plt.close(comperp_result["figure"]) - plt.close(chanplot_fig) - - -def test_pop_chanplot_gui_filters_channels(sample_epoch): - class Renderer: - def __init__(self): - self.spec = None - - def run(self, spec, initial_values=None): - self.spec = spec - return {"channels": "1 2", "measure": 1} - - second = deepcopy(sample_epoch) - renderer = Renderer() - study, command, figure = pop_chanplot( - {"name": "demo study"}, [sample_epoch, second], gui=True, renderer=renderer, return_com=True - ) - - assert renderer.spec is not None - assert study["etc"]["last_chanplot"]["channels"] == [1, 2] - assert "channels=[1, 2]" in command - assert "measure='erp'" in command - _assert_python_command(command) - plt.close(figure) - - -def test_pop_comperp_rms_mode_and_grid_validation(sample_epoch): - first = deepcopy(sample_epoch) - second = deepcopy(sample_epoch) - first["data"] = np.ones_like(first["data"]) - second["data"] = -np.ones_like(second["data"]) - - ave, _command = pop_comperp([first, second], flag=1, datadd=[1, 2], mode="ave", return_com=True) - rms, _command = pop_comperp([first, second], flag=1, datadd=[1, 2], mode="rms", return_com=True) - - np.testing.assert_allclose(ave["erp1"], 0) - np.testing.assert_allclose(rms["erp1"], 1) - plt.close(ave["figure"]) - plt.close(rms["figure"]) - second["xmax"] = float(second["xmax"]) + 0.1 - with pytest.raises(ValueError, match="time grid"): - pop_comperp([first, second], flag=1, datadd=[1, 2]) - - -def test_pop_comperp_supports_display_options_and_significance(sample_epoch): - datasets = [deepcopy(sample_epoch) for _ in range(4)] - scales = [1.0, 1.12, 0.78, 1.31] - for offset, dataset in enumerate(datasets): - dataset["data"] = np.asarray(dataset["data"], dtype=float) * scales[offset] + offset * 0.05 - dataset["setname"] = f"dataset {offset + 1}" - - result, command = pop_comperp( - datasets, - flag=1, - datadd=[3, 4], - datsub=[1, 2], - chans=[1, 2], - alpha=0.05, - addstd="on", - substd="on", - diffstd="on", - addall="on", - suball="on", - diffall="on", - tlim=[-50, 100], - ylim=[-20, 20], - title="Compared ERPs", - return_com=True, - ) - - assert result["erp1"].shape == (2, sample_epoch["pnts"]) - assert result["erp2"].shape == (2, sample_epoch["pnts"]) - assert result["erpsub"].shape == (2, sample_epoch["pnts"]) - assert result["pvalues"].shape == (2, sample_epoch["pnts"]) - assert np.isfinite(result["pvalues"]).any() - assert "addstd='on'" in command - _assert_python_command(command) - plt.close(result["figure"]) - - with pytest.raises(ValueError, match="unsupported option"): - pop_comperp(datasets, flag=1, datadd=[1, 2], unsupported="on") - - -def test_pop_comperp_significance_shading_marks_known_effect(sample_epoch): - datasets = [] - for amplitude in (1.0, 1.1, 0.9): - dataset = deepcopy(sample_epoch) - dataset["data"] = np.zeros_like(np.asarray(dataset["data"], dtype=float)) + amplitude - datasets.append(dataset) - for _index in range(3): - dataset = deepcopy(sample_epoch) - dataset["data"] = np.zeros_like(np.asarray(dataset["data"], dtype=float)) - datasets.append(dataset) - - result = pop_comperp(datasets, flag=1, datadd=[1, 2, 3], datsub=[4, 5, 6], chans=[1, 2], alpha=0.01) - - significant_patches = [ - patch for patch in result["figure"].axes[0].patches if patch.get_alpha() == pytest.approx(0.18) - ] - assert significant_patches - plt.close(result["figure"]) - - -def test_pop_chanplot_validates_time_grid(sample_epoch): - second = deepcopy(sample_epoch) - second["xmax"] = float(second["xmax"]) + 0.1 - - with pytest.raises(ValueError, match="time grid"): - pop_chanplot({"name": "demo study"}, [sample_epoch, second], channels=[1]) - - -def test_pop_erpimage_applies_time_limits_and_decimation(sample_epoch): - result, command = pop_erpimage( - sample_epoch, - typeplot=1, - index=1, - limits=[-50, 100], - decimate=2, - caxis=[-1, 1], - cbar=False, - return_com=True, - ) - - expected_samples = np.count_nonzero((sample_epoch["times"] >= -50) & (sample_epoch["times"] <= 100)) - assert result["image"].shape[1] == expected_samples - assert result["image"].shape[0] == int(np.ceil(sample_epoch["trials"] / 2)) - assert "limits=[-50, 100]" in command - _assert_python_command(command) - plt.close(result["figure"]) - - -def test_pop_erpimage_sorts_by_epoch_event_field_and_limits(sample_epoch): - eeg = deepcopy(sample_epoch) - eeg["data"] = np.asarray( - [ - [ - [1.0, 2.0, 3.0], - [1.0, 2.0, 3.0], - [1.0, 2.0, 3.0], - [1.0, 2.0, 3.0], - ] - ] - ) - eeg["nbchan"] = 1 - eeg["pnts"] = 4 - eeg["trials"] = 3 - eeg["srate"] = 100.0 - eeg["xmin"] = 0.0 - eeg["xmax"] = 0.03 - eeg["times"] = np.asarray([0.0, 10.0, 20.0, 30.0]) - eeg["event"] = [ - {"type": "rt", "latency": 2, "epoch": 1, "rt": 30}, - {"type": "rt", "latency": 6, "epoch": 2, "rt": 10}, - {"type": "rt", "latency": 10, "epoch": 3, "rt": 20}, - ] - - result, command = pop_erpimage( - eeg, - typeplot=1, - index=1, - sortingeventfield="rt", - sortingtype=["rt"], - sortingwin=[0, 20], - return_com=True, - ) - unsorted, _command = pop_erpimage(eeg, typeplot=1, index=1, sort_values=[30, 10, 20], nosort=True, return_com=True) - - np.testing.assert_allclose(result["image"][:, 0], [2, 3, 1]) - np.testing.assert_allclose(unsorted["image"][:, 0], [1, 2, 3]) - assert "sortingeventfield='rt'" in command - _assert_python_command(command) - plt.close(result["figure"]) - plt.close(unsorted["figure"]) - - with pytest.raises(ValueError, match="standalone ERP image"): - pop_erpimage(eeg, typeplot=1, index=1, align=[0]) - - -def test_plot_history_preserves_effective_options(sample_epoch, ica_epoch): - timtopo_fig, timtopo_command = pop_timtopo( - sample_epoch, - plottimes=[0], - timerange=[-50, 100], - winsize=[10], - title="custom timtopo", - return_com=True, - ) - plottopo_fig, plottopo_command = pop_plottopo( - sample_epoch, - chans=[1], - timerange=[-50, 100], - title="custom plottopo", - return_com=True, - ) - envtopo_fig, envtopo_command = pop_envtopo( - ica_epoch, - timerange=[0, 100], - components=[1], - title="custom envtopo", - return_com=True, - ) - - assert "timerange=[-50, 100]" in timtopo_command - assert "winsize=[10]" in timtopo_command - assert "title='custom timtopo'" in timtopo_command - assert "timerange=[-50, 100]" in plottopo_command - assert "title='custom plottopo'" in plottopo_command - assert "components=[1]" in envtopo_command - assert "title='custom envtopo'" in envtopo_command - for command in (timtopo_command, plottopo_command, envtopo_command): - _assert_python_command(command) - plt.close(timtopo_fig) - plt.close(plottopo_fig) - plt.close(envtopo_fig) - - -def test_pop_spectopo_component_path_plots_component_maps(ica_epoch): - result, command = pop_spectopo( - ica_epoch, - dataflag=0, - freqs=[10], - icacomps=[1, 2], - icamaps=[1], - return_com=True, - ) - - assert result["figure"] is not None - assert len(result["figure"].axes) >= 2 - assert result["specstd"] is None - assert "icamaps=[1]" in command - _assert_python_command(command) - plt.close(result["figure"]) - - -def test_plot_wrappers_fail_clearly_when_required_fields_are_missing(sample_epoch, tmp_path): - with pytest.raises(ValueError, match="ICA"): - pop_plotdata(sample_epoch, components=[1]) - continuous = deepcopy(sample_epoch) - continuous["trials"] = 1 - continuous["data"] = continuous["data"][:, :, 0] - with pytest.raises(ValueError, match="epoched"): - pop_erpimage(continuous, typeplot=1, index=1) - splinefile = tmp_path / "epoch.spl" - headplot_setup( - sample_epoch["chanlocs"], - splinefile, - chaninfo=sample_epoch["chaninfo"], - transform=[0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100], - ) - epoched = deepcopy(sample_epoch) - epoched["splinefile"] = str(splinefile) - with pytest.raises(ValueError, match="outside the epoch time range"): - pop_headplot(epoched, typeplot=1, items=[1e6]) - with pytest.raises(ValueError, match="cannot find spline file"): - pop_headplot(sample_epoch, typeplot=1, items=[0]) - no_locs = deepcopy(sample_epoch) - no_locs["chanlocs"] = [] - with pytest.raises(ValueError, match="does not contain channel locations"): - pop_headplot(no_locs, typeplot=1, items=[0]) - - -def _assert_python_command(command: str) -> None: - ast.parse(command) - - -def test_component_activations_dedup_contract(): - """Lock the K4 dedup: rejection delegates recompute to the canonical helper. - - The rejection ``component_activations`` (``_rejection``) and the canonical - plotting helper (``_plot_utils``) must agree when recomputing from weights, - and rejection must ignore a stored ``icaact`` while plotting trusts it. - """ - from eegprep.functions.popfunc._rejection import component_activations as rejection_activations - - rng = np.random.default_rng(7) - nbchan, pnts, trials = 5, 16, 4 - data = rng.standard_normal((nbchan, pnts, trials)) - weights = rng.standard_normal((nbchan, nbchan)) - sphere = rng.standard_normal((nbchan, nbchan)) - recompute = (weights @ sphere) @ data.reshape(nbchan, -1, order="F") - stored = -recompute.reshape(nbchan, pnts, trials, order="F") - eeg = { - "data": data, - "icaweights": weights, - "icasphere": sphere, - "icachansind": np.arange(nbchan), - "nbchan": nbchan, - "pnts": pnts, - "trials": trials, - "icaact": stored, - } - - plot_recompute = component_activations(eeg, use_stored=False) - assert np.allclose(rejection_activations(eeg), plot_recompute) - # Rejection ignores the stored icaact; the default plotting path trusts it. - assert not np.allclose(rejection_activations(eeg), stored) - assert np.allclose(component_activations(eeg), stored) - - -def _matlab_string(path: Any) -> str: - return str(path).replace("'", "''") - - -def _matlab_vector(values: list[float]) -> str: - return " ".join(str(value) for value in values) - - -def _eeglab_reference_root() -> Path: - repo_root = Path(__file__).resolve().parents[1] - package_reference = repo_root / "src" / "eegprep" / "eeglab" - if (package_reference / "functions" / "popfunc" / "pop_headplot.m").exists(): - return package_reference - sibling_reference = repo_root.parent / "eeglab" - if (sibling_reference / "functions" / "popfunc" / "pop_headplot.m").exists(): - return sibling_reference - return package_reference diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py deleted file mode 100644 index 55d10f70..00000000 --- a/tests/test_pipeline.py +++ /dev/null @@ -1,316 +0,0 @@ -""" -Test suite for pipeline: clean_artifacts -> eeg_picard -> iclabel -""" - -import os -import re -import unittest -from copy import deepcopy -from eegprep import pop_loadset, clean_artifacts, eeg_picard, iclabel -from eegprep.functions.adminfunc.eeglabcompat import get_eeglab -from eegprep.functions.popfunc.eeg_compare import eeg_compare -from eegprep.utils.testing import ( - compare_eeg, - DebuggableTestCase, - has_optional_dependency, - matlab_function_exists, -) - - -@unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") -def test_pipeline(): - """Test pipeline: clean_artifacts -> eeg_picard -> iclabel, comparing Python and MATLAB at each step.""" - # where the test resources - local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') - fname = os.path.join(local_url, 'eeglab_data_with_ica_tmp.set') - EEG = pop_loadset(fname) - - # --- Step 1: clean_artifacts (channel cleaning only) --- - # Channel cleaning only: BurstCriterion='off' - EEG_py_ch, *_ = clean_artifacts(deepcopy(EEG), BurstCriterion='off', ChannelCriterion=0.8) - eeglab = get_eeglab('MAT') - EEG_mat_ch = eeglab.clean_artifacts(deepcopy(EEG), 'BurstCriterion', 'off', 'ChannelCriterion', 0.8) - # eeg_compare(EEG_py_ch, EEG_mat_ch) - compare_eeg( - EEG_py_ch['data'], - EEG_mat_ch['data'], - rtol=0.005, - atol=1e-5, - err_msg='clean_artifacts() channel cleaning Python vs MATLAB failed', - ) - print("clean_artifacts() channel cleaning Python vs MATLAB passed\n\n\n") - - -@unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") -class TestPipeline(DebuggableTestCase): - """Test pipeline: clean_artifacts -> eeg_picard -> iclabel, comparing Python and MATLAB at each step.""" - - def setUp(self): - """Set up test fixtures.""" - local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') - fname = os.path.join(local_url, 'eeglab_data_with_ica_tmp.set') - self.EEG = pop_loadset(fname) - self.eeglab = get_eeglab('MAT') - self.has_matlab_picard = matlab_function_exists(self.eeglab, 'eeg_picard') - - def test_clean_artifacts_channel_cleaning(self): - """Test clean_artifacts channel cleaning step (BurstCriterion='off').""" - # Channel cleaning only: BurstCriterion='off' - # Use deepcopy to ensure Python doesn't modify the original EEG - EEG_py_ch, *_ = clean_artifacts(deepcopy(self.EEG), BurstCriterion='off', ChannelCriterion=0.8) - # MATLAB also needs a fresh copy since it may modify the EEG structure - EEG_mat_ch = self.eeglab.clean_artifacts(deepcopy(self.EEG), 'BurstCriterion', 'off', 'ChannelCriterion', 0.8) - - print("\n" + "=" * 80) - print("Step 1: clean_artifacts (channel cleaning only)") - print("=" * 80) - summary = compare_eeg( - EEG_py_ch['data'], - EEG_mat_ch['data'], - rtol=0.005, - atol=1e-5, - err_msg='clean_artifacts() channel cleaning Python vs MATLAB failed', - ) - print(summary) - print("=" * 80 + "\n") - - def test_clean_artifacts_burst_cleaning(self): - """Test clean_artifacts burst cleaning step (ChannelCriterion='off').""" - # First do channel cleaning - EEG_py_ch, *_ = clean_artifacts(deepcopy(self.EEG), BurstCriterion='off', ChannelCriterion=0.8) - EEG_mat_ch = self.eeglab.clean_artifacts(deepcopy(self.EEG), 'BurstCriterion', 'off', 'ChannelCriterion', 0.8) - - # Then do burst cleaning only: ChannelCriterion='off' - EEG_py, *_ = clean_artifacts(EEG_py_ch, ChannelCriterion='off') - EEG_mat = self.eeglab.clean_artifacts(EEG_mat_ch, 'ChannelCriterion', 'off', 'BurstCriterion', 5.0) - - print("\n" + "=" * 80) - print("Step 1b: clean_artifacts (burst cleaning only)") - print("=" * 80) - eeg_summary = eeg_compare(EEG_py, EEG_mat) - print(f"\n{eeg_summary}") - data_summary = compare_eeg( - EEG_py['data'], - EEG_mat['data'], - rtol=0.005, - atol=1e-5, - err_msg='clean_artifacts() burst cleaning Python vs MATLAB failed', - ) - print(f"\n{data_summary}") - print("=" * 80 + "\n") - - def test_eeg_picard(self): - """Test eeg_picard ICA decomposition.""" - if not self.has_matlab_picard: - self.skipTest("MATLAB EEGLAB Picard plugin is not installed") - - # Prepare data: channel cleaning + burst cleaning - EEG_py_ch, *_ = clean_artifacts(deepcopy(self.EEG), BurstCriterion='off', ChannelCriterion=0.8) - EEG_mat_ch = self.eeglab.clean_artifacts(deepcopy(self.EEG), 'BurstCriterion', 'off', 'ChannelCriterion', 0.8) - EEG_py, *_ = clean_artifacts(EEG_py_ch, ChannelCriterion='off') - EEG_mat = self.eeglab.clean_artifacts(EEG_mat_ch, 'ChannelCriterion', 'off', 'BurstCriterion', 5.0) - - # Run ICA - EEG_py_ica = eeg_picard(EEG_py) - EEG_mat_ica = eeg_picard(EEG_mat, engine=self.eeglab) - - # Compare ICA fields - print("\n" + "=" * 80) - print("Step 2: eeg_picard (ICA decomposition)") - print("=" * 80) - for field in ['icaweights', 'icasphere', 'icawinv', 'icaact', 'icachansind']: - self.assertIn(field, EEG_py_ica, f"Missing ICA field in Python: {field}") - self.assertIn(field, EEG_mat_ica, f"Missing ICA field in MATLAB: {field}") - - print("\nComparing icaweights:") - weights_summary = eeg_compare(EEG_py_ica['icaweights'], EEG_mat_ica['icaweights']) - print(weights_summary) - - print("\nComparing icasphere:") - sphere_summary = eeg_compare(EEG_py_ica['icasphere'], EEG_mat_ica['icasphere']) - print(sphere_summary) - - print("\nComparing icawinv:") - winv_summary = eeg_compare(EEG_py_ica['icawinv'], EEG_mat_ica['icawinv']) - print(winv_summary) - print("=" * 80 + "\n") - - def test_iclabel(self): - """Test iclabel component classification.""" - if not self.has_matlab_picard: - self.skipTest("MATLAB EEGLAB Picard plugin is not installed") - if not has_optional_dependency('torch'): - self.skipTest("PyTorch is not installed; install eegprep[torch] to run ICLabel parity") - - # Prepare data: channel cleaning + burst cleaning + ICA - EEG_py_ch, *_ = clean_artifacts(deepcopy(self.EEG), BurstCriterion='off', ChannelCriterion=0.8) - EEG_mat_ch = self.eeglab.clean_artifacts(deepcopy(self.EEG), 'BurstCriterion', 'off', 'ChannelCriterion', 0.8) - EEG_py, *_ = clean_artifacts(EEG_py_ch, ChannelCriterion='off') - EEG_mat = self.eeglab.clean_artifacts(EEG_mat_ch, 'ChannelCriterion', 'off', 'BurstCriterion', 5.0) - EEG_py_ica = eeg_picard(EEG_py) - EEG_mat_ica = eeg_picard(EEG_mat, engine=self.eeglab) - - # Run ICLabel - EEG_py_lbl = iclabel(EEG_py_ica) - EEG_mat_lbl = iclabel(EEG_mat_ica, engine='matlab') - - # Check ICLabel output structure - print("\n" + "=" * 80) - print("Step 3: iclabel (component classification)") - print("=" * 80) - for EEG_lbl in [EEG_py_lbl, EEG_mat_lbl]: - self.assertIn('etc', EEG_lbl, 'Missing etc field') - self.assertIn('ic_classification', EEG_lbl['etc'], 'Missing ic_classification field') - self.assertIn('ICLabel', EEG_lbl['etc']['ic_classification'], 'Missing ICLabel field') - - res_py = EEG_py_lbl['etc']['ic_classification']['ICLabel']['classifications'].flatten() - res_mat = EEG_mat_lbl['etc']['ic_classification']['ICLabel']['classifications'].flatten() - print("\nComparing ICLabel classifications:") - iclabel_summary = eeg_compare(res_py, res_mat) - print(iclabel_summary) - print("=" * 80 + "\n") - - def test_z_full_pipeline(self): - """Test the complete pipeline end-to-end.""" - if not self.has_matlab_picard: - self.skipTest("MATLAB EEGLAB Picard plugin is not installed") - if not has_optional_dependency('torch'): - self.skipTest("PyTorch is not installed; install eegprep[torch] to run full pipeline parity") - - print("\n" + "=" * 80) - print("Full Pipeline Test: clean_artifacts -> eeg_picard -> iclabel") - print("=" * 80) - - # Run the pipeline once and collect all eeg_compare summaries - summaries = {} - - # Step 1: Channel cleaning - EEG_py_ch, *_ = clean_artifacts(deepcopy(self.EEG), BurstCriterion='off', ChannelCriterion=0.8) - EEG_mat_ch = self.eeglab.clean_artifacts(deepcopy(self.EEG), 'BurstCriterion', 'off', 'ChannelCriterion', 0.8) - data_summary_1 = compare_eeg( - EEG_py_ch['data'], - EEG_mat_ch['data'], - rtol=0.005, - atol=1e-5, - err_msg='clean_artifacts() channel cleaning Python vs MATLAB failed', - ) - print(f"\nStep 1 - Channel cleaning data comparison:\n{data_summary_1}") - - # Step 1b: Burst cleaning - EEG_py, *_ = clean_artifacts(EEG_py_ch, ChannelCriterion='off') - EEG_mat = self.eeglab.clean_artifacts(EEG_mat_ch, 'ChannelCriterion', 'off', 'BurstCriterion', 5.0) - summaries['burst_cleaning_eeg'] = eeg_compare(EEG_py, EEG_mat) - data_summary_1b = compare_eeg( - EEG_py['data'], - EEG_mat['data'], - rtol=0.005, - atol=1e-5, - err_msg='clean_artifacts() burst cleaning Python vs MATLAB failed', - ) - print(f"\nStep 1b - Burst cleaning data comparison:\n{data_summary_1b}") - - # Step 2: ICA - EEG_py_ica = eeg_picard(EEG_py) - EEG_mat_ica = eeg_picard(EEG_mat, engine=self.eeglab) - summaries['icaweights'] = eeg_compare(EEG_py_ica['icaweights'], EEG_mat_ica['icaweights']) - summaries['icasphere'] = eeg_compare(EEG_py_ica['icasphere'], EEG_mat_ica['icasphere']) - summaries['icawinv'] = eeg_compare(EEG_py_ica['icawinv'], EEG_mat_ica['icawinv']) - - # Step 3: ICLabel - EEG_py_lbl = iclabel(EEG_py_ica) - EEG_mat_lbl = iclabel(EEG_mat_ica, engine='matlab') - res_py = EEG_py_lbl['etc']['ic_classification']['ICLabel']['classifications'].flatten() - res_mat = EEG_mat_lbl['etc']['ic_classification']['ICLabel']['classifications'].flatten() - summaries['iclabel_classifications'] = eeg_compare(res_py, res_mat) - - # Print consolidated eeg_compare summary as a table - print("\n" + "=" * 80) - print("Full Pipeline Test - Consolidated eeg_compare Summary Table") - print("=" * 80) - - # Helper function to extract metrics from summary string - def extract_metrics(summary_str): - """Extract key metrics from summary string.""" - metrics = { - 'max_abs_diff': 'N/A', - 'mean_abs_diff': 'N/A', - 'rms_diff': 'N/A', - 'max_rel_diff': 'N/A', - 'mismatch_pct': 'N/A', - } - if 'Array Comparison Summary' in summary_str: - for line in summary_str.split('\n'): - if 'Max absolute difference' in line: - metrics['max_abs_diff'] = line.split(':')[1].strip() - elif 'Mean absolute difference' in line: - metrics['mean_abs_diff'] = line.split(':')[1].strip() - elif 'RMS difference' in line: - metrics['rms_diff'] = line.split(':')[1].strip() - elif 'Max relative difference' in line: - metrics['max_rel_diff'] = line.split(':')[1].strip() - elif 'Mismatched elements' in line and '%' in line: - # Extract percentage like "Mismatched elements (> 1e-10): 900 (100.00%)" - # Find the last occurrence of (X.XX%) pattern - match = re.search(r'\(([\d.]+)%\)', line) - if match: - metrics['mismatch_pct'] = match.group(1) + '%' - else: - metrics['mismatch_pct'] = 'N/A' - elif 'Found' in summary_str and 'differences' in summary_str: - # For EEG structure comparisons - metrics['max_abs_diff'] = 'See details' - metrics['mismatch_pct'] = summary_str.split('Found')[1].split('total')[0].strip() + ' diff' - return metrics - - # Organize summaries by step - step_data = [] - - # Step 1: clean_artifacts (use burst_cleaning_eeg as it's the final output) - if 'burst_cleaning_eeg' in summaries: - step1_summary = summaries['burst_cleaning_eeg'] - step1_metrics = extract_metrics(step1_summary) - step_data.append(('Step 1: clean_artifacts', step1_metrics, step1_summary)) - - # Step 2: eeg_picard (combine ICA arrays - show all) - ica_arrays = ['icaweights', 'icasphere', 'icawinv'] - for idx, array_name in enumerate(ica_arrays): - if array_name in summaries: - array_summary = summaries[array_name] - array_metrics = extract_metrics(array_summary) - if idx == 0: - step_name = 'Step 2: eeg_picard' - else: - step_name = '' - step_data.append((f' {array_name}' if idx > 0 else step_name, array_metrics, array_summary)) - - # Step 3: iclabel - if 'iclabel_classifications' in summaries: - step3_summary = summaries['iclabel_classifications'] - step3_metrics = extract_metrics(step3_summary) - step_data.append(('Step 3: iclabel', step3_metrics, step3_summary)) - - # Print table - print( - f"\n{'Step':<30} {'Max Abs Diff':<18} {'Mean Abs Diff':<18} {'RMS Diff':<18} {'Max Rel Diff':<18} {'Mismatch %':<15}" - ) - print("-" * 120) - - for step_name, metrics, _ in step_data: - print( - f"{step_name:<30} {metrics['max_abs_diff']:<18} {metrics['mean_abs_diff']:<18} " - f"{metrics['rms_diff']:<18} {metrics['max_rel_diff']:<18} {metrics['mismatch_pct']:<15}" - ) - - print("-" * 120) - print("\nDetailed summaries:") - for step_name, _, summary in step_data: - print(f"\n{step_name}:") - print(summary) - - print("\n" + "=" * 80) - print("Full pipeline test completed successfully!") - print("=" * 80 + "\n") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_plugin_menu.py b/tests/test_plugin_menu.py index 95f005f2..d9f03dfd 100644 --- a/tests/test_plugin_menu.py +++ b/tests/test_plugin_menu.py @@ -35,7 +35,7 @@ def test_bundled_plugins_describe_in_repo_extensions() -> None: plugins = bundled_plugins() names = [plugin["plugin"] for plugin in plugins] - assert names == ["clean_rawdata", "ICLabel", "firfilt", "dipfit", "EEG_BIDS"] + assert names == ["firfilt", "dipfit", "EEG_BIDS"] assert all(plugin["installed"] is True for plugin in plugins) assert all(plugin["status"] == "ok" for plugin in plugins) assert all(plugin["source"] == "bundled" for plugin in plugins) @@ -77,16 +77,16 @@ def test_bundled_plugins_returns_copies() -> None: def test_plugin_status_supports_partial_and_exact_matches() -> None: - partial_status, partial_names, partial_struct = plugin_status("label") - exact_status, exact_names, exact_struct = plugin_status("ICLabel", exactmatch=True) + partial_status, partial_names, partial_struct = plugin_status("filt") + exact_status, exact_names, exact_struct = plugin_status("firfilt", exactmatch=True) missing_status, missing_names, missing_struct = plugin_status("external") assert partial_status == [1] - assert partial_names == ["ICLabel"] - assert partial_struct[0]["name"] == "ICLabel" + assert partial_names == ["firfilt"] + assert partial_struct[0]["name"] == "firfilt" assert exact_status == [1] - assert exact_names == ["ICLabel"] - assert exact_struct[0]["foldername"] == "ICLabel" + assert exact_names == ["firfilt"] + assert exact_struct[0]["foldername"] == "firfilt" assert missing_status == [] assert missing_names == [] assert missing_struct == [] @@ -98,20 +98,18 @@ def test_plugin_menu_updates_session_without_gui() -> None: assert session.PLUGINLIST == plugins assert [plugin["plugin"] for plugin in session.PLUGINLIST] == [ - "clean_rawdata", - "ICLabel", "firfilt", "dipfit", "EEG_BIDS", ] - assert [plugin["status"] for plugin in session.PLUGINLIST] == ["bundled"] * 5 + assert [plugin["status"] for plugin in session.PLUGINLIST] == ["bundled"] * 3 def test_format_plugin_menu_includes_external_plugin_exclusion() -> None: text = format_plugin_menu() assert "Available EEGPrep extensions" in text - assert "ICLabel" in text + assert "firfilt" in text assert "File > Import data / Export / BIDS tools" in text assert EXTERNAL_PLUGIN_NOTICE in text assert INSTALL_TRUST_WARNING in text @@ -124,8 +122,6 @@ def test_file_menu_plugin_action_uses_bundled_inventory_headlessly() -> None: dispatcher.dispatch("plugin_menu", parent=None) assert [plugin["plugin"] for plugin in session.PLUGINLIST] == [ - "clean_rawdata", - "ICLabel", "firfilt", "dipfit", "EEG_BIDS", diff --git a/tests/test_pop_prop_extended.py b/tests/test_pop_prop_extended.py deleted file mode 100644 index b5c2d399..00000000 --- a/tests/test_pop_prop_extended.py +++ /dev/null @@ -1,201 +0,0 @@ -import matplotlib - -matplotlib.use("Agg") - -import matplotlib.pyplot as plt -import numpy as np -import pytest - -from eegprep.plugins.ICLabel import _prop_browser, _prop_numerics -from eegprep.plugins.ICLabel import pop_prop_extended as pop_prop_extended_module -from eegprep.plugins.ICLabel.pop_prop_extended import ( - DEFAULT_ICLABEL_CLASSES, - build_extended_property_data, - classifier_name_from_gui, - classifier_names, - resolve_classifier_data, - component_rejection_status, - resolve_dipfit_data, - selected_property_indices, -) -from eegprep.functions.popfunc._rejection import component_rejection_flags, set_component_rejection_flag -from eegprep.plugins.ICLabel.pop_viewprops import pop_viewprops -from tests.fixtures import create_test_eeg_with_ica - - -def _iclabel_eeg(*, include_classifier: bool = True) -> dict: - rng = np.random.default_rng(8) - eeg = create_test_eeg_with_ica(n_channels=4, n_samples=120, srate=100.0, n_components=4, n_trials=2) - eeg["xmin"] = -0.1 - eeg["xmax"] = 1.09 - eeg["times"] = np.linspace(-100.0, 1090.0, 120) - eeg["data"] = rng.normal(0.0, 1.0, (4, 120, 2)) - eeg["icaweights"] = np.eye(4) - eeg["icasphere"] = np.eye(4) - eeg["icawinv"] = np.eye(4) - eeg["icaact"] = rng.normal(0.0, 0.5, (4, 120, 2)) - eeg["icachansind"] = np.arange(4) - eeg["event"] = [ - {"type": "stim", "latency": 20.0, "duration": 0.0}, - {"type": "resp", "latency": 80.0, "duration": 0.0}, - ] - eeg["reject"] = {"gcompreject": np.zeros(4, dtype=int)} - if include_classifier: - eeg["etc"] = { - "ic_classification": { - "Other": { - "classifications": np.full((4, 3), 1.0 / 3.0), - "classes": ["One", "Two", "Three"], - }, - "ICLabel": { - "classifications": np.array( - [ - [0.70, 0.10, 0.10, 0.03, 0.02, 0.03, 0.02], - [0.02, 0.94, 0.02, 0.01, 0.00, 0.00, 0.01], - [0.05, 0.02, 0.91, 0.01, 0.00, 0.00, 0.01], - [0.80, 0.05, 0.05, 0.02, 0.02, 0.03, 0.03], - ] - ), - }, - } - } - return eeg - - -def test_classifier_data_parsing_defaults_to_iclabel_and_standard_classes() -> None: - eeg = _iclabel_eeg() - - classifier = resolve_classifier_data(eeg, component_total=4, require=True) - - assert classifier.name == "ICLabel" - assert classifier.classes == DEFAULT_ICLABEL_CLASSES - np.testing.assert_allclose(classifier.probabilities[1], [0.02, 0.94, 0.02, 0.01, 0.0, 0.0, 0.01]) - assert classifier_names(eeg) == ["Other", "ICLabel"] - - -def test_channel_property_data_accepts_numpy_chanlocs() -> None: - eeg = _iclabel_eeg() - eeg["chanlocs"] = np.asarray(eeg["chanlocs"], dtype=object) - - dashboard = build_extended_property_data(eeg, 1, 1) - - assert len(dashboard.topography_chanlocs) == 4 - assert dashboard.topography_chanlocs[0]["labels"] == "Ch1" - - -def test_classifier_name_from_gui_matches_string_values_case_insensitively() -> None: - eeg = _iclabel_eeg() - - assert classifier_name_from_gui(eeg, "iclabel") == "ICLabel" - assert classifier_name_from_gui(eeg, "OTHER") == "Other" - - -def test_classifier_data_rejects_component_count_mismatch() -> None: - eeg = _iclabel_eeg() - - with pytest.raises(ValueError, match="rows for 3 ICA components"): - resolve_classifier_data(eeg, "ICLabel", component_total=3, require=True) - - -def test_selected_component_indices_are_eeglab_facing_one_based() -> None: - eeg = _iclabel_eeg() - - assert selected_property_indices(eeg, 0, "1:2") == [1, 2] - assert selected_property_indices(eeg, 0, [], default_all=True) == [1, 2, 3, 4] - with pytest.raises(ValueError, match="Index out of range"): - selected_property_indices(eeg, 0, [0]) - - -def test_component_rejection_flags_use_one_based_component_indices() -> None: - eeg = _iclabel_eeg() - eeg["reject"]["gcompreject"] = np.array([0, 1, 0, 0]) - - assert component_rejection_status(eeg, 2) is True - assert component_rejection_status(eeg, 1) is False - updated = set_component_rejection_flag(eeg, 3, True, 4) - - np.testing.assert_array_equal(updated, [0, 1, 1, 0]) - np.testing.assert_array_equal(eeg["reject"]["gcompreject"], [0, 1, 1, 0]) - - -def test_component_rejection_flags_initialize_missing_or_stale_vectors() -> None: - eeg = _iclabel_eeg() - eeg["reject"]["gcompreject"] = np.array([1, 1]) - - flags = component_rejection_flags(eeg, 4, create=True) - - np.testing.assert_array_equal(flags, [False, False, False, False]) - np.testing.assert_array_equal(eeg["reject"]["gcompreject"], [0, 0, 0, 0]) - - -def test_dashboard_data_assembly_includes_classification_surfaces() -> None: - eeg = _iclabel_eeg() - - dashboard = build_extended_property_data(eeg, 0, 2, spec_opt="'freqrange', [2 40]") - - assert dashboard.figure_title == "IC2 - pop_prop_extended()" - assert dashboard.topography_title == "IC2" - assert dashboard.activity_title == "Scrolling IC2 Activity" - assert dashboard.classifier is not None - assert dashboard.classifier.name == "ICLabel" - assert dashboard.class_probabilities is not None - assert dashboard.class_probabilities[1] == pytest.approx(0.94) - assert dashboard.spectrum_freqs.size == dashboard.spectrum_power.size - assert dashboard.image_data.size > 0 - assert dashboard.pvaf is None or np.isfinite(dashboard.pvaf) - assert dashboard.dipfit is None - assert dashboard.rejected is False - - -def test_dashboard_data_assembly_includes_localized_dipfit_model() -> None: - eeg = _iclabel_eeg() - eeg["dipfit"] = { - "coordformat": "MNI", - "model": [ - {"posxyz": [0, -20, 40], "momxyz": [1, 0, 0], "rv": 0.12, "component": 1}, - { - "posxyz": [[25, 10, 35], [-25, 10, 35]], - "momxyz": [[0, 1, 0], [0, 2, 0]], - "rv": 0.2, - "component": 2, - }, - ], - } - - dashboard = build_extended_property_data(eeg, 0, 2) - - assert dashboard.dipfit is not None - assert dashboard.dipfit.coordformat == "MNI" - assert dashboard.dipfit.rv_percent == pytest.approx(20.0) - assert dashboard.dipfit.dmr == pytest.approx(2.0) - first_dipfit = resolve_dipfit_data(eeg, 1) - assert first_dipfit is not None - np.testing.assert_allclose(dashboard.dipfit.positions, [[25, 10, 35], [-25, 10, 35]]) - np.testing.assert_allclose(first_dipfit.positions, [[0, -20, 40]]) - - -def test_dipfit_data_rejects_malformed_localized_model() -> None: - eeg = _iclabel_eeg() - eeg["dipfit"] = {"model": [{"posxyz": [1, 2], "momxyz": [1, 0, 0], "rv": 0.1}]} - - with pytest.raises(ValueError, match="posxyz rows with 3 coordinates"): - resolve_dipfit_data(eeg, 1) - - -def test_missing_classifier_falls_back_to_lightweight_viewprops_display() -> None: - eeg = _iclabel_eeg(include_classifier=False) - - figures = pop_viewprops(eeg, 0, [1, 2], plot=True, show_activity=False) - - assert len(figures) == 1 - assert not hasattr(figures[0], "eegprep_dashboard_data") - assert len(figures[0].eegprep_activity_views) == 2 - assert figures[0].eegprep_activity_views[0].state.events - plt.close(figures[0]) - - -def test_pop_prop_extended_facade_preserves_public_helper_imports() -> None: - assert pop_prop_extended_module.build_extended_property_data is _prop_numerics.build_extended_property_data - assert pop_prop_extended_module.resolve_classifier_data is _prop_numerics.resolve_classifier_data - assert pop_prop_extended_module.resolve_dipfit_data is _prop_numerics.resolve_dipfit_data - assert _prop_browser.build_navigable_dashboard.__module__.endswith("._prop_browser") diff --git a/tests/test_rejection_workflows.py b/tests/test_rejection_workflows.py deleted file mode 100644 index 09c470f6..00000000 --- a/tests/test_rejection_workflows.py +++ /dev/null @@ -1,791 +0,0 @@ -import copy -import os -from pathlib import Path -import shutil -import subprocess - -import matplotlib.pyplot as plt -import numpy as np -import pytest -import scipy.io - -import eegprep.functions.popfunc._eegplot_rejection as eegplot_rejection_module -import eegprep.functions.popfunc.pop_rejcont as pop_rejcont_module -from eegprep.functions.adminfunc.console import _console_python_command -from eegprep.functions.popfunc.pop_eegplot import DEFAULT_REJECTION_COLORS -from eegprep.functions.popfunc.eeg_rejsuperpose import eeg_rejsuperpose -from eegprep.functions.popfunc._rejection import ( - jointprob, - jointprob_marks, - kurtosis_marks, - rejkurt, - trend_marks, -) -from eegprep.functions.popfunc.pop_autorej import pop_autorej -from eegprep.functions.popfunc.pop_eegthresh import pop_eegthresh -from eegprep.functions.popfunc.pop_jointprob import pop_jointprob -from eegprep.functions.popfunc.pop_loadset import pop_loadset -from eegprep.functions.popfunc.pop_rejchan import pop_rejchan -from eegprep.functions.popfunc.pop_rejcont import pop_rejcont -from eegprep.functions.popfunc.pop_rejepoch import pop_rejepoch -from eegprep.functions.popfunc.pop_rejkurt import pop_rejkurt -from eegprep.functions.popfunc.pop_rejmenu import pop_rejmenu -from eegprep.functions.popfunc.pop_rejspec import pop_rejspec -from eegprep.functions.popfunc.pop_rejtrend import pop_rejtrend -from eegprep.functions.popfunc.pop_selectcomps import pop_selectcomps -from eegprep.plugins.ICLabel.pop_viewprops import pop_viewprops -from tests.fixtures import SAMPLE_DATASET_PATH, create_test_eeg - - -def _epoched_eeg() -> dict: - rng = np.random.default_rng(4) - eeg = create_test_eeg(n_channels=4, n_samples=80, n_trials=5, srate=100) - data = rng.normal(0, 0.05, (4, 80, 5)) - data[0, 10:20, 1] = 25 - data[1, :, 2] += np.linspace(0, 8, 80) - data[2, :, 3] += 4 * np.sin(2 * np.pi * 25 * np.arange(80) / 100) - eeg["data"] = data - eeg["icaweights"] = np.eye(4) - eeg["icasphere"] = np.eye(4) - eeg["icawinv"] = np.eye(4) - eeg["icachansind"] = np.arange(4) - eeg["icaact"] = None - eeg["reject"] = {} - return eeg - - -def _reference_trend_marks( - data: np.ndarray, selected_rows: list[int], *, winsize: int, maxslope: float, min_r: float -) -> np.ndarray: - row_marks = np.zeros((data.shape[0], data.shape[2]), dtype=bool) - x = np.linspace(1 / winsize, 1, winsize) - tolerance = 1000 * winsize * 1.1921e-7 - for row_index in selected_rows: - for trial in range(data.shape[2]): - for start in range(0, data.shape[1] - winsize + 1, winsize): - y = data[row_index, start : start + winsize, trial] - slope, intercept = np.polyfit(x, y, 1) - fit = slope * x + intercept - sst = max(float(np.sum((y - y.mean()) ** 2)), tolerance) - r2 = 1 - float(np.sum((y - fit) ** 2)) / sst - if abs(slope) >= maxslope and r2 > min_r: - row_marks[row_index, trial] = True - break - return row_marks - - -def test_pop_eegthresh_marks_epochs_and_emits_replayable_python(): - eeg = _epoched_eeg() - - out, com = pop_eegthresh(eeg, 1, [1], -10, 10, 0, 0.79, 0, 0, return_com=True) - - assert out["reject"]["rejthresh"].tolist() == [False, True, False, False, False] - assert out["reject"]["rejthreshE"][0].tolist() == [False, True, False, False, False] - assert _console_python_command(com) == ( - "EEG = pop_eegthresh(EEG, icacomp=1, elecrange=[1], negthresh=[-10], " - "posthresh=[10], starttime=[0], endtime=[0.79], superpose=0, reject=0)" - ) - - -def test_rejection_statistics_store_data_and_component_marks(): - eeg = _epoched_eeg() - - prob_out, _local, _global, prob_count = pop_jointprob(eeg, 1, [1, 2, 3, 4], 1.2, 1.2, 0, 0) - kurt_out, _local, _global, kurt_count = pop_rejkurt(eeg, 1, [1, 2, 3, 4], 1.2, 1.2, 0, 0) - trend_out = pop_rejtrend(eeg, 1, [2], 80, 0.2, 0.3, 0, 0) - spec_out, spec_indices = pop_rejspec( - eeg, - 1, - "method", - "multitaper", - "elecrange", - [3], - "threshold", - [-10, 10], - "freqlimits", - [20, 30], - "eegplotreject", - 0, - ) - fft_spec_out, fft_spec_indices = pop_rejspec( - eeg, - 1, - "method", - "fft", - "elecrange", - [3], - "threshold", - [-10, 10], - "freqlimits", - [20, 30], - "eegplotreject", - 0, - ) - comp_out, _local, _global, comp_count = pop_jointprob(eeg, 0, [1, 2, 3, 4], 1.2, 1.2, 0, 0) - - assert prob_count >= 1 - assert kurt_count >= 0 - assert prob_out["reject"]["rejjpE"].shape == (4, 5) - assert kurt_out["reject"]["rejkurtE"].shape == (4, 5) - assert trend_out["reject"]["rejconst"][2] - assert spec_indices - assert fft_spec_indices - assert spec_out["specdata"].shape[:2] == (4, 40) - assert fft_spec_out["specdata"].shape == spec_out["specdata"].shape - assert not np.allclose(fft_spec_out["specdata"], spec_out["specdata"]) - assert comp_count >= 1 - assert "icarejjp" in comp_out["reject"] - - -@pytest.mark.parametrize( - ("runner", "field"), - [ - ( - lambda eeg, callback: pop_eegthresh( - eeg, - 1, - [1], - -10, - 10, - 0, - 0.79, - 0, - 0, - topcommand="update", - command_callback=callback, - return_com=True, - ), - "rejthresh", - ), - ( - lambda eeg, callback: pop_jointprob( - eeg, 1, [1, 2, 3, 4], 1.2, 1.2, 0, 0, 1, command_callback=callback, return_com=True - ), - "rejjp", - ), - ( - lambda eeg, callback: pop_rejkurt( - eeg, 1, [1, 2, 3, 4], 1.2, 1.2, 0, 0, 1, command_callback=callback, return_com=True - ), - "rejkurt", - ), - ( - lambda eeg, callback: pop_rejtrend( - eeg, 1, [2], 80, 0.2, 0.3, 0, 0, 1, command_callback=callback, return_com=True - ), - "rejconst", - ), - ( - lambda eeg, callback: pop_rejspec( - eeg, - 1, - "method", - "fft", - "elecrange", - [3], - "threshold", - [-10, 10], - "freqlimits", - [20, 30], - "eegplotplotallrej", - 2, - "eegplotreject", - 0, - command_callback=callback, - return_com=True, - ), - "rejfreq", - ), - ], -) -def test_epoched_rejection_display_paths_open_browser_and_accept_marks(monkeypatch, runner, field): - calls = [] - accepted = [] - - def fake_eegplot(data, *args, **kwargs): - del args - calls.append((np.asarray(data), kwargs)) - kwargs["command_callback"](kwargs["winrej"]) - return "window" - - monkeypatch.setattr(eegplot_rejection_module, "eegplot", fake_eegplot) - eeg = _epoched_eeg() - - out, command = runner(eeg, lambda eeg_out, accept_command: accepted.append((eeg_out, accept_command))) - - assert command - assert out["trials"] == eeg["trials"] - assert calls - assert calls[0][0].ndim == 3 - assert accepted - assert accepted[0][1] == command - assert accepted[0][0]["reject"][field].shape == (eeg["trials"],) - - -def test_component_rejection_browser_accept_updates_ica_marks(monkeypatch): - calls = [] - accepted = [] - - def fake_eegplot(data, *args, **kwargs): - del args - calls.append((np.asarray(data), kwargs)) - kwargs["command_callback"](kwargs["winrej"]) - return "window" - - monkeypatch.setattr(eegplot_rejection_module, "eegplot", fake_eegplot) - eeg = _epoched_eeg() - - out, _command = pop_jointprob( - eeg, - 0, - [1], - 1.2, - 1.2, - 0, - 0, - 1, - command_callback=lambda eeg_out, command: accepted.append((eeg_out, command)), - return_com=True, - ) - - assert calls[0][0].shape[0] == 1 - assert out["trials"] == eeg["trials"] - assert "icarejjp" in accepted[0][0]["reject"] - assert accepted[0][0]["reject"]["icarejjpE"].shape == (4, 5) - - -def test_reject_on_browser_accept_removes_epochs_without_immediate_rejection(monkeypatch): - accepted = [] - - def fake_eegplot(_data, *args, **kwargs): - del args - kwargs["command_callback"](kwargs["winrej"]) - return "window" - - monkeypatch.setattr(eegplot_rejection_module, "eegplot", fake_eegplot) - eeg = _epoched_eeg() - - out, _command = pop_eegthresh( - eeg, - 1, - [1], - -10, - 10, - 0, - 0.79, - 0, - 1, - topcommand="reject", - command_callback=lambda eeg_out, command: accepted.append(eeg_out), - return_com=True, - ) - - assert out["trials"] == eeg["trials"] - assert accepted[0]["trials"] == eeg["trials"] - 1 - - -def test_superposed_browser_winrej_includes_existing_family_marks(monkeypatch): - calls = [] - - def fake_eegplot(_data, *args, **kwargs): - del args - calls.append(kwargs) - return "window" - - monkeypatch.setattr(eegplot_rejection_module, "eegplot", fake_eegplot) - eeg = _epoched_eeg() - eeg["reject"]["rejthresh"] = np.array([True, False, False, False, False]) - eeg["reject"]["rejthreshE"] = np.zeros((4, 5), dtype=bool) - eeg["reject"]["rejthreshE"][0, 0] = True - eeg["reject"]["disprej"] = ["thresh"] - - pop_jointprob(eeg, 1, [1, 2, 3, 4], 1.2, 1.2, 2, 0, 1) - - rows = calls[0]["winrej"] - assert any(np.allclose(row[2:5], DEFAULT_REJECTION_COLORS["thresh"]) for row in rows) - - -def test_pop_rejcont_display_accept_removes_continuous_regions(monkeypatch): - accepted = [] - eeg = create_test_eeg(n_channels=2, n_samples=120, n_trials=1, srate=100) - eeg["chanlocs"] = np.asarray(eeg["chanlocs"], dtype=object) - time = np.arange(120) / 100 - eeg["data"][0] = 100 * np.sin(2 * np.pi * 30 * time) - - def fake_eegplot(data, *args, **kwargs): - del args - assert np.asarray(data).shape[0] == 1 - assert kwargs["eloc_file"][0]["labels"] == "Ch1" - kwargs["command_callback"](kwargs["winrej"]) - return "window" - - monkeypatch.setattr(pop_rejcont_module, "eegplot", fake_eegplot) - - out, selected = pop_rejcont( - eeg, - "elecrange", - [1], - "freqlimit", - [20, 40], - "threshold", - 0, - "epochlength", - 0.2, - "contiguous", - 1, - "eegplot", - "on", - command_callback=lambda eeg_out, command: accepted.append((eeg_out, command)), - ) - - assert selected.size - assert out["pnts"] == eeg["pnts"] - assert accepted[0][0]["pnts"] < eeg["pnts"] - - -def test_pop_rejcont_display_defers_history_command_until_browser_accept(monkeypatch): - accepted = [] - eeg = create_test_eeg(n_channels=2, n_samples=120, n_trials=1, srate=100) - time = np.arange(120) / 100 - eeg["data"][0] = 100 * np.sin(2 * np.pi * 30 * time) - - def fake_eegplot(_data, *args, **kwargs): - del args - kwargs["command_callback"](kwargs["winrej"]) - return "window" - - monkeypatch.setattr(pop_rejcont_module, "eegplot", fake_eegplot) - - out, command = pop_rejcont( - eeg, - "elecrange", - [1], - "freqlimit", - [20, 40], - "threshold", - 0, - "epochlength", - 0.2, - "contiguous", - 1, - "eegplot", - "on", - command_callback=lambda eeg_out, accept_command: accepted.append((eeg_out, accept_command)), - return_com=True, - ) - - assert out is eeg - assert command == "" - assert accepted[0][1].startswith("EEG = pop_rejcont(EEG, ") - - -def test_pop_autorej_display_marks_original_epochs_before_browser_accept(monkeypatch): - calls = [] - accepted = [] - - def fake_eegplot(_data, *args, **kwargs): - del args - calls.append(kwargs) - kwargs["command_callback"](kwargs["winrej"]) - return "window" - - monkeypatch.setattr(eegplot_rejection_module, "eegplot", fake_eegplot) - eeg = _epoched_eeg() - - out, command = pop_autorej( - eeg, - "threshold", - 10, - "startprob", - 20, - "maxrej", - 40, - "nogui", - "on", - "eegplot", - "on", - command_callback=lambda eeg_out, accept_command: accepted.append((eeg_out, accept_command)), - return_com=True, - ) - - assert command - assert calls - assert out["trials"] == eeg["trials"] - assert out["reject"]["rejauto"].shape == (eeg["trials"],) - assert accepted[0][1] == command - - -def test_jointprob_global_marks_match_eeglab_trial_rows_for_duplicate_channels(): - rng = np.random.default_rng(42) - data = rng.normal(size=(3, 12, 4)) - elecrange = [3, 1, 3] - - reject, row_marks, local_scores, global_scores = jointprob_marks(data, elecrange, 0.8, 0.8) - - selected = [2, 0, 2] - expected_local_scores, expected_local = jointprob(data[selected], 0.8, normalize=1) - global_data = data[selected].transpose(2, 0, 1).reshape(data.shape[2], -1) - expected_global_scores, expected_global = jointprob(global_data, 0.8, normalize=1) - expected_reject = expected_local.any(axis=0) | expected_global.ravel() - - np.testing.assert_allclose(local_scores, expected_local_scores) - np.testing.assert_allclose(global_scores, expected_global_scores.ravel()) - np.testing.assert_array_equal(reject, expected_reject) - np.testing.assert_array_equal(row_marks[0], expected_local[1]) - np.testing.assert_array_equal(row_marks[2], expected_local[2]) - - -def test_jointprob_global_threshold_can_reject_when_local_threshold_does_not(): - data = np.array( - [ - [ - [4.0, 3.0], - [2.0, 1.0], - [1.0, 0.0], - [0.0, 0.0], - [0.0, 4.0], - [3.0, 4.0], - [2.0, 3.0], - [4.0, 3.0], - [3.0, 2.0], - [2.0, 4.0], - [1.0, 4.0], - [3.0, 0.0], - ], - [ - [1.0, 4.0], - [2.0, 0.0], - [3.0, 3.0], - [4.0, 0.0], - [0.0, 4.0], - [0.0, 2.0], - [0.0, 1.0], - [2.0, 2.0], - [2.0, 0.0], - [0.0, 0.0], - [0.0, 3.0], - [2.0, 3.0], - ], - ] - ) - - reject, row_marks, _local_scores, global_scores = jointprob_marks(data, [1, 2], 10, 0.5) - - np.testing.assert_array_equal(row_marks.any(axis=0), [False, False]) - np.testing.assert_allclose(global_scores, [1 / np.sqrt(2), -1 / np.sqrt(2)]) - np.testing.assert_array_equal(reject, [True, True]) - - -def test_kurtosis_global_marks_match_eeglab_trial_rows_for_duplicate_channels(): - rng = np.random.default_rng(7) - data = rng.normal(size=(3, 16, 4)) - elecrange = [2, 1, 2] - - reject, row_marks, local_scores, global_scores = kurtosis_marks(data, elecrange, 0.6, 0.6) - - selected = [1, 0, 1] - expected_local_scores, expected_local = rejkurt(data[selected], 0.6, normalize=1) - global_data = data[selected].transpose(2, 0, 1).reshape(data.shape[2], -1) - expected_global_scores, expected_global = rejkurt(global_data, 0.6, normalize=1) - expected_reject = expected_local.any(axis=0) | expected_global.ravel() - - np.testing.assert_allclose(local_scores, expected_local_scores) - np.testing.assert_allclose(global_scores, expected_global_scores.ravel()) - np.testing.assert_array_equal(reject, expected_reject) - np.testing.assert_array_equal(row_marks[0], expected_local[1]) - np.testing.assert_array_equal(row_marks[1], expected_local[2]) - - -def test_kurtosis_global_threshold_can_reject_when_local_threshold_does_not(): - rng = np.random.default_rng(0) - data = rng.normal(size=(2, 12, 2)) - data[:, :, 1] *= 0.1 - - reject, row_marks, _local_scores, global_scores = kurtosis_marks(data, [1, 2], 10, 0.5) - - np.testing.assert_array_equal(row_marks.any(axis=0), [False, False]) - np.testing.assert_allclose(global_scores, [1 / np.sqrt(2), -1 / np.sqrt(2)]) - np.testing.assert_array_equal(reject, [True, True]) - - -def test_trend_marks_match_reference_window_loop(): - data = np.zeros((2, 12, 3), dtype=float) - data[0, :, 0] = np.arange(12, dtype=float) - data[0, :, 1] = 0.1 - data[1, :, 2] = np.r_[np.arange(6, dtype=float), np.zeros(6)] - - reject, row_marks = trend_marks(data, [1, 2], winsize=6, maxslope=0.3, min_r=0.8) - expected = _reference_trend_marks(data, [0, 1], winsize=6, maxslope=0.3, min_r=0.8) - - np.testing.assert_array_equal(row_marks, expected) - np.testing.assert_array_equal(reject, expected.any(axis=0)) - - -def test_eeg_rejsuperpose_and_pop_rejepoch_remove_marked_epochs(): - eeg = _epoched_eeg() - eeg["reject"]["rejmanual"] = np.array([False, True, False, False, True]) - eeg["reject"]["rejmanualE"] = np.zeros((4, 5), dtype=bool) - eeg["reject"]["rejmanualE"][0, 1] = True - eeg["reject"]["rejmanualE"][1, 4] = True - - marked, com = eeg_rejsuperpose(eeg, 1, 1, 0, 0, 0, 0, 0, 0, return_com=True) - - assert marked["reject"]["rejglobal"].tolist() == [False, True, False, False, True] - assert marked["reject"]["rejglobalE"].shape == (4, 5) - removed, reject_com = pop_rejepoch(copy.deepcopy(marked), marked["reject"]["rejglobal"], 0, return_com=True) - assert removed["trials"] == 3 - assert _console_python_command(com) == "EEG = eeg_rejsuperpose(EEG, 1, 1, 0, 0, 0, 0, 0, 0)" - assert _console_python_command(reject_com) == "EEG = pop_rejepoch(EEG, tmprej=[2, 5], confirm=0)" - - -def test_eeg_rejsuperpose_only_crosses_trial_marks_between_data_and_ica_families(): - eeg = _epoched_eeg() - eeg["icaweights"] = np.ones((2, 4)) - eeg["icawinv"] = np.ones((4, 2)) - eeg["reject"] = { - "rejmanual": np.array([False, True, False, False, False]), - "rejmanualE": np.zeros((4, 5), dtype=bool), - "icarejmanual": np.array([False, False, True, False, False]), - "icarejmanualE": np.ones((2, 5), dtype=bool), - } - - marked = eeg_rejsuperpose(eeg, 1, 1, 0, 0, 0, 0, 0, 1) - - assert marked["reject"]["rejglobal"].tolist() == [False, True, True, False, False] - assert marked["reject"]["rejglobalE"].shape == (4, 5) - assert not marked["reject"]["rejglobalE"].any() - - -@pytest.mark.matlab -def test_eeg_rejsuperpose_matches_eeglab_for_deterministic_marks(tmp_path): - if os.environ.get("EEGPREP_SKIP_MATLAB") == "1": - pytest.skip("MATLAB tests disabled via EEGPREP_SKIP_MATLAB") - matlab = shutil.which("matlab") - if matlab is None: - pytest.skip("MATLAB executable not available") - eeglab_root = _eeglab_root() - if eeglab_root is None: - pytest.skip("EEGLAB source not available for parity reference") - - eeg = _epoched_eeg() - reject = { - "rejmanual": np.array([False, True, False, False, False]), - "rejmanualE": np.array( - [ - [False, True, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - ] - ), - "rejthresh": np.array([False, False, True, False, False]), - "rejthreshE": np.array( - [ - [False, False, True, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - ] - ), - "rejconst": np.zeros(5, dtype=bool), - "rejconstE": np.zeros((4, 5), dtype=bool), - "rejjp": np.array([False, False, False, True, False]), - "rejjpE": np.array( - [ - [False, False, False, True, False], - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - ] - ), - "rejkurt": np.zeros(5, dtype=bool), - "rejkurtE": np.zeros((4, 5), dtype=bool), - "rejfreq": np.array([True, False, False, False, False]), - "rejfreqE": np.array( - [ - [True, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - ] - ), - } - eeg["reject"] = reject - py_out = eeg_rejsuperpose(eeg, 1, 1, 1, 1, 1, 1, 1, 0) - - script = tmp_path / "eeg_rejsuperpose_parity.m" - output = tmp_path / "out.mat" - script.write_text(_matlab_rejsuperpose_script(eeglab_root, output), encoding="utf-8") - result = subprocess.run( - [matlab, "-batch", f"run('{script.as_posix()}')"], check=False, capture_output=True, text=True - ) - if result.returncode: - pytest.fail(result.stdout + result.stderr) - matlab_out = scipy.io.loadmat(output) - - np.testing.assert_array_equal( - np.asarray(py_out["reject"]["rejglobal"], dtype=bool), matlab_out["rejglobal"].ravel() - ) - np.testing.assert_array_equal(np.asarray(py_out["reject"]["rejglobalE"], dtype=bool), matlab_out["rejglobalE"]) - - -def test_pop_rejmenu_can_combine_marks_without_browser(): - eeg = _epoched_eeg() - eeg["reject"]["rejthresh"] = np.array([False, True, False, False, False]) - eeg["reject"]["rejthreshE"] = np.zeros((4, 5), dtype=bool) - - out, com = pop_rejmenu(eeg, 1, gui=False, return_com=True) - - assert out["reject"]["rejglobal"].tolist() == [False, True, False, False, False] - assert _console_python_command(com) == "EEG = eeg_rejsuperpose(EEG, 1, 1, 1, 1, 1, 1, 1, 1)" - - -def test_pop_autorej_preserves_original_epoch_numbers_during_iterative_rejection(): - eeg = _epoched_eeg() - - out, rejected = pop_autorej(eeg, "threshold", 10, "startprob", 20, "maxrej", 40, "nogui", "on") - - assert eeg["trials"] - out["trials"] == len(rejected) - assert rejected == sorted(set(rejected)) - - -def test_channel_and_continuous_rejection_work_on_sample_data_without_ica(): - sample = pop_loadset(SAMPLE_DATASET_PATH) - - _, rejected_channels, measure = pop_rejchan(sample, "measure", "std", "threshold", 1e9, "indexonly", "on") - _, selected_regions = pop_rejcont( - sample, - "elecrange", - [1], - "threshold", - 1e9, - "epochlength", - 0.5, - "contiguous", - 1, - "onlyreturnselection", - "on", - ) - - assert rejected_channels == [] - assert measure.shape == (32,) - assert selected_regions.shape == (0, 2) - with pytest.raises(ValueError, match="ICA decomposition is required"): - pop_eegthresh(sample, 0, [1], -10, 10, 0, 1) - - -def test_rejection_component_threshold_recomputes_stale_stored_icaact(): - eeg = _epoched_eeg() - eeg["icaweights"] = 2.0 * np.eye(4) - eeg["icasphere"] = np.eye(4) - eeg["icaact"] = np.zeros((4, eeg["pnts"], eeg["trials"])) - - out, rejected = pop_eegthresh(eeg, 0, [1], -40, 40, 0, 0.79, 0, 0) - - assert rejected == [2] - assert out["reject"]["icarejthresh"].tolist() == [False, True, False, False, False] - - -def test_pop_rejchan_default_threshold_matches_gui_zscore_default(): - eeg = create_test_eeg(n_channels=2, n_samples=20, n_trials=1, srate=100) - eeg["data"] = np.zeros((2, 20)) - eeg["data"][0, 10] = 100 - - _out, rejected_channels, _measure = pop_rejchan(eeg, "measure", "std", "indexonly", "on") - - assert rejected_channels == [1] - - -def test_pop_rejcont_history_replays_effectful_mode_and_overlap_options(): - sample = pop_loadset(SAMPLE_DATASET_PATH) - - _out, command = pop_rejcont( - sample, - "elecrange", - [1], - "freqlimit", - [20, 40], - "threshold", - 1e9, - "epochlength", - 0.5, - "overlap", - 0.1, - "mode", - "mean", - "onlyreturnselection", - "on", - return_com=True, - ) - - assert _console_python_command(command) == ( - "EEG = pop_rejcont(EEG, elecrange=[1], freqlimit=[20, 40], threshold=1000000000, " - "epochlength=0.5, overlap=0.1, mode='mean', onlyreturnselection='on')" - ) - - -def test_component_selection_and_viewprops_are_replayable_without_scrolling_browser(): - eeg = _epoched_eeg() - - selected, select_com = pop_selectcomps(eeg, [1, 3], reject=[2], plot=False, return_com=True) - figures, props_com = pop_viewprops(eeg, 0, [1, 2], plot=False, return_com=True) - - assert selected["reject"]["gcompreject"].tolist() == [False, True, False, False] - assert figures == [] - assert _console_python_command(select_com) == "EEG = pop_selectcomps(EEG, compnum=[1, 3], reject=[2])" - assert _console_python_command(props_com) == ( - "pop_viewprops(EEG, typecomp=0, chanorcomp=[1, 2], spec_opt=[], erp_opt=[], scroll_event=1, classifier_name='')" - ) - - -def test_gui_cancel_paths_leave_datasets_unchanged(): - class CancelRenderer: - def run(self, spec, initial_values=None): - return None - - eeg = _epoched_eeg() - out, com = pop_eegthresh(eeg, gui=True, renderer=CancelRenderer(), return_com=True) - rejchan_out, rejchan_com = pop_rejchan(copy.deepcopy(eeg), gui=True, renderer=CancelRenderer(), return_com=True) - - assert out is eeg - assert com == "" - assert rejchan_out["data"].shape == eeg["data"].shape - assert rejchan_com == "" - plt.close("all") - - -def _eeglab_root() -> Path | None: - candidates = [] - if os.environ.get("EEGPREP_EEGLAB_ROOT"): - candidates.append(Path(os.environ["EEGPREP_EEGLAB_ROOT"])) - candidates.append(Path(__file__).resolve().parents[1] / "src" / "eegprep" / "eeglab") - for candidate in candidates: - if (candidate / "functions" / "popfunc" / "eeg_rejsuperpose.m").exists(): - return candidate - return None - - -def _matlab_rejsuperpose_script(eeglab_root: Path, output: Path) -> str: - return f""" -addpath(fullfile('{eeglab_root.as_posix()}', 'functions', 'popfunc')); -EEG = struct(); -EEG.trials = 5; -EEG.nbchan = 4; -EEG.reject = struct(); -EEG.reject.rejmanual = logical([0 1 0 0 0]); -EEG.reject.rejmanualE = logical([0 1 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0]); -EEG.reject.rejthresh = logical([0 0 1 0 0]); -EEG.reject.rejthreshE = logical([0 0 1 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0]); -EEG.reject.rejconst = logical([0 0 0 0 0]); -EEG.reject.rejconstE = logical(zeros(4, 5)); -EEG.reject.rejjp = logical([0 0 0 1 0]); -EEG.reject.rejjpE = logical([0 0 0 1 0; 0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0]); -EEG.reject.rejkurt = logical([0 0 0 0 0]); -EEG.reject.rejkurtE = logical(zeros(4, 5)); -EEG.reject.rejfreq = logical([1 0 0 0 0]); -EEG.reject.rejfreqE = logical([1 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0]); -EEG = eeg_rejsuperpose(EEG, 1, 1, 1, 1, 1, 1, 1, 0); -rejglobal = EEG.reject.rejglobal; -rejglobalE = EEG.reject.rejglobalE; -save('{output.as_posix()}', 'rejglobal', 'rejglobalE'); -""" diff --git a/tests/test_sample_data_pop_functions.py b/tests/test_sample_data_pop_functions.py index 1137e8a5..07bfa3c1 100644 --- a/tests/test_sample_data_pop_functions.py +++ b/tests/test_sample_data_pop_functions.py @@ -54,13 +54,6 @@ from eegprep.plugins.EEG_BIDS.bids_tools import pop_eventinfo, pop_participantinfo, pop_taskinfo, validate_bids from eegprep.plugins.EEG_BIDS.pop_exportbids import pop_exportbids from eegprep.plugins.EEG_BIDS.pop_importbids import pop_importbids -from eegprep.plugins.ICLabel.pop_iclabel import pop_iclabel -from eegprep.plugins.ICLabel.pop_icflag import DEFAULT_ICFLAG_THRESHOLDS, pop_icflag -from eegprep.plugins.clean_rawdata.clean_artifacts import clean_artifacts -from eegprep.plugins.clean_rawdata.clean_asr import clean_asr -from eegprep.plugins.clean_rawdata.clean_channels import clean_channels -from eegprep.plugins.clean_rawdata.clean_windows import clean_windows -from eegprep.plugins.clean_rawdata.pop_clean_rawdata import pop_clean_rawdata SAMPLE_SET = Path("sample_data/eeglab_data.set") diff --git a/tests/test_spatial.py b/tests/test_spatial.py deleted file mode 100644 index e6a6965e..00000000 --- a/tests/test_spatial.py +++ /dev/null @@ -1,187 +0,0 @@ -""" -Tests for eegprep.plugins.clean_rawdata.private.sphericalSplineInterpolate module. - -This module tests the spherical spline interpolation functions used in topographic plotting. -""" - -import unittest -import numpy as np -from eegprep.plugins.clean_rawdata.private.sphericalSplineInterpolate import sphericalSplineInterpolate, _interpMx - - -class TestInterpMx(unittest.TestCase): - """Test the _interpMx helper function for Legendre polynomial approximation.""" - - def test_interp_mx_basic(self): - """Test basic _interpMx functionality.""" - # Test with simple cosine values - cosang = np.array([[1.0, 0.5, 0.0, -0.5, -1.0]]) - G, H = _interpMx(cosang, order=4, tol=1e-10) - - # Should return matrices with same shape as input - self.assertEqual(G.shape, cosang.shape) - self.assertEqual(H.shape, cosang.shape) - - # Should be finite - self.assertTrue(np.all(np.isfinite(G))) - self.assertTrue(np.all(np.isfinite(H))) - - def test_interp_mx_single_point(self): - """Test _interpMx with single point.""" - cosang = np.array([[1.0]]) - G, H = _interpMx(cosang, order=4, tol=1e-10) - - self.assertEqual(G.shape, (1, 1)) - self.assertEqual(H.shape, (1, 1)) - self.assertTrue(np.isfinite(G[0, 0])) - self.assertTrue(np.isfinite(H[0, 0])) - - def test_interp_mx_edge_values(self): - """Test _interpMx with edge cosine values.""" - # Test extreme cosine values - cosang = np.array([[-1.0, -0.999, 0.0, 0.999, 1.0]]) - G, H = _interpMx(cosang, order=4, tol=1e-10) - - # All values should be finite - self.assertTrue(np.all(np.isfinite(G))) - self.assertTrue(np.all(np.isfinite(H))) - - def test_interp_mx_numerical_stability(self): - """Test _interpMx numerical stability with various inputs.""" - # Test with very small differences - cosang = np.array([[1.0, 0.9999, 0.9998, 0.9997]]) - G, H = _interpMx(cosang, order=4, tol=1e-10) - - self.assertTrue(np.all(np.isfinite(G))) - self.assertTrue(np.all(np.isfinite(H))) - - def test_interp_mx_empty_input(self): - """Test _interpMx with empty input.""" - cosang = np.array([[]]) - G, H = _interpMx(cosang, order=4, tol=1e-10) - - self.assertEqual(G.size, 0) - self.assertEqual(H.size, 0) - - -class TestSphericalSplineInterpolate(unittest.TestCase): - """Test the sphericalSplineInterpolate function.""" - - def setUp(self): - """Set up test fixtures.""" - # Create simple 3D electrode positions (as expected by the function) - # 4 electrodes on unit sphere - self.src_positions = np.array( - [ - [1.0, 0.0, 0.0, -1.0], # x coordinates - [0.0, 1.0, 0.0, 0.0], # y coordinates - [0.0, 0.0, 1.0, 0.0], # z coordinates - ] - ) - - # Destination positions (where to interpolate) - self.dest_positions = np.array( - [ - [0.707, -0.707], # x coordinates - [0.707, 0.707], # y coordinates - [0.0, 0.0], # z coordinates - ] - ) - - def test_spherical_spline_interpolate_basic(self): - """Test basic sphericalSplineInterpolate functionality.""" - W, Gss, Gds, Hds = sphericalSplineInterpolate(self.src_positions, self.dest_positions) - - # Check return value shapes - n_src = self.src_positions.shape[1] - n_dest = self.dest_positions.shape[1] - - self.assertEqual(W.shape, (n_dest, n_src)) - self.assertEqual(Gss.shape, (n_src, n_src)) - self.assertEqual(Gds.shape, (n_dest, n_src)) - self.assertEqual(Hds.shape, (n_dest, n_src)) - - # All values should be finite - self.assertTrue(np.all(np.isfinite(W))) - self.assertTrue(np.all(np.isfinite(Gss))) - self.assertTrue(np.all(np.isfinite(Gds))) - self.assertTrue(np.all(np.isfinite(Hds))) - - def test_spherical_spline_interpolate_different_types(self): - """Test sphericalSplineInterpolate with different interpolation types.""" - # Test 'spline' type (default) - W_spline, _, _, _ = sphericalSplineInterpolate(self.src_positions, self.dest_positions, type='spline') - - # Test 'slap' type - W_slap, _, _, _ = sphericalSplineInterpolate(self.src_positions, self.dest_positions, type='slap') - - # Both should return finite matrices of the same shape - self.assertEqual(W_spline.shape, W_slap.shape) - self.assertTrue(np.all(np.isfinite(W_spline))) - self.assertTrue(np.all(np.isfinite(W_slap))) - - # Results should be different for different interpolation types - self.assertFalse(np.allclose(W_spline, W_slap)) - - def test_spherical_spline_interpolate_regularization(self): - """Test sphericalSplineInterpolate with different regularization parameters.""" - # Test with different lambda values - W_low_reg, _, _, _ = sphericalSplineInterpolate(self.src_positions, self.dest_positions, lambda_reg=1e-10) - - W_high_reg, _, _, _ = sphericalSplineInterpolate(self.src_positions, self.dest_positions, lambda_reg=1e-2) - - # Both should be finite and same shape - self.assertEqual(W_low_reg.shape, W_high_reg.shape) - self.assertTrue(np.all(np.isfinite(W_low_reg))) - self.assertTrue(np.all(np.isfinite(W_high_reg))) - - def test_spherical_spline_interpolate_different_orders(self): - """Test sphericalSplineInterpolate with different polynomial orders.""" - # Test with different orders - W_order2, _, _, _ = sphericalSplineInterpolate(self.src_positions, self.dest_positions, order=2) - - W_order6, _, _, _ = sphericalSplineInterpolate(self.src_positions, self.dest_positions, order=6) - - # Both should be finite and same shape - self.assertEqual(W_order2.shape, W_order6.shape) - self.assertTrue(np.all(np.isfinite(W_order2))) - self.assertTrue(np.all(np.isfinite(W_order6))) - - -class TestSphericalSplineInterpolateErrorHandling(unittest.TestCase): - """Test error handling in sphericalSplineInterpolate.""" - - def test_invalid_input_dimensions(self): - """Test error handling for invalid input dimensions.""" - # Wrong shape for src (should be 3 x N) - src_wrong = np.array([[1, 2], [3, 4]]) # 2x2 instead of 3xN - dest = np.array([[0.5], [0.5], [0.0]]) # 3x1 - - with self.assertRaises(ValueError): - sphericalSplineInterpolate(src_wrong, dest) - - def test_invalid_interpolation_type(self): - """Test error handling for invalid interpolation type.""" - src = np.array([[1, 0], [0, 1], [0, 0]]) # 3x2 - dest = np.array([[0.5], [0.5], [0.0]]) # 3x1 - - with self.assertRaises(ValueError): - sphericalSplineInterpolate(src, dest, type='invalid_type') - - def test_empty_inputs(self): - """Test error handling for empty inputs.""" - empty_src = np.array([[], [], []]) # 3x0 - dest = np.array([[0.5], [0.5], [0.0]]) # 3x1 - - # This should either raise an error or handle gracefully - try: - W, _, _, _ = sphericalSplineInterpolate(empty_src, dest) - # If it succeeds, W should be empty - self.assertEqual(W.shape[1], 0) - except (ValueError, np.linalg.LinAlgError): - # Expected behavior for empty input - pass - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_utils_asr.py b/tests/test_utils_asr.py deleted file mode 100644 index 6e4f22a9..00000000 --- a/tests/test_utils_asr.py +++ /dev/null @@ -1,738 +0,0 @@ -import unittest -import numpy as np -from unittest.mock import patch - -from eegprep.plugins.clean_rawdata.asr_calibrate import asr_calibrate -from eegprep.plugins.clean_rawdata.asr_process import asr_process -from eegprep.plugins.clean_rawdata.clean_asr import clean_asr - - -class TestAsrCalibrate(unittest.TestCase): - """Test the asr_calibrate function.""" - - def setUp(self): - """Set up test fixtures with synthetic EEG data.""" - np.random.seed(42) # For reproducible tests - self.n_channels = 8 - self.n_samples = 1000 - self.srate = 250.0 - - # Create synthetic clean EEG data (zero-mean) - self.clean_data = np.random.randn(self.n_channels, self.n_samples) * 0.5 - - # Add some realistic structure (autocorrelation) - for i in range(self.n_channels): - # Simple AR(1) process for more realistic EEG-like data - for j in range(1, self.n_samples): - self.clean_data[i, j] += 0.8 * self.clean_data[i, j - 1] - - def test_basic_calibration(self): - """Test basic ASR calibration functionality.""" - state = asr_calibrate(self.clean_data, self.srate) - - # Check that all required state variables are present - required_keys = ['M', 'T', 'B', 'A', 'sos', 'iir_state'] - for key in required_keys: - self.assertIn(key, state) - - # Check matrix shapes - self.assertEqual(state['M'].shape, (self.n_channels, self.n_channels)) - self.assertEqual(state['T'].shape, (self.n_channels, self.n_channels)) - - # Check that M is symmetric and positive definite (within tolerance) - M = state['M'] - np.testing.assert_allclose(M, M.T, atol=1e-10) - eigenvals = np.linalg.eigvals(M) - self.assertTrue(np.all(eigenvals > -1e-10)) # Allow small numerical errors - - # Check that filter coefficients are finite - self.assertTrue(np.all(np.isfinite(state['B']))) - self.assertTrue(np.all(np.isfinite(state['A']))) - - def test_different_sampling_rates(self): - """Test calibration with different sampling rates and precomputed filters.""" - test_srates = [100, 128, 200, 256, 300, 500, 512] - - for srate in test_srates: - with self.subTest(srate=srate): - # Create appropriate data size for the sampling rate - n_samples = max(1000, int(srate * 2)) # At least 2 seconds - data = np.random.randn(4, n_samples) * 0.3 - - state = asr_calibrate(data, srate) - - # Check that state is valid - self.assertIsInstance(state, dict) - self.assertIn('M', state) - self.assertIn('T', state) - - # Check filter coefficients are reasonable - self.assertTrue(len(state['B']) > 1) - self.assertTrue(len(state['A']) > 0) - - def test_unsupported_sampling_rate_raises(self): - """Unsupported sampling rates must fail loudly, not silently degrade. - - Common rates like 999/1000/1024 Hz have no pre-computed spectral filter. - Substituting a trivial difference filter would silently miscalibrate ASR - thresholds, so asr_calibrate must raise rather than warn-and-continue. - """ - data = np.random.randn(4, 1000) * 0.3 - - for unsupported_srate in (999.0, 1000.0, 1024.0): - with self.subTest(srate=unsupported_srate): - with self.assertRaises(ValueError) as cm: - asr_calibrate(data, unsupported_srate) - self.assertIn('No pre-computed ASR spectral filter', str(cm.exception)) - - def test_unsupported_sampling_rate_allows_explicit_filter(self): - """An explicit B/A bypasses the precomputed-filter lookup for any srate.""" - data = np.random.randn(4, 1000) * 0.3 - B = np.array([1.0, -0.5]) - A = np.array([1.0]) - - state = asr_calibrate(data, 999.0, B=B, A=A) - self.assertIn('M', state) - np.testing.assert_array_equal(state['B'], B) - np.testing.assert_array_equal(state['A'], A) - - def test_parameter_validation(self): - """Test parameter validation and edge cases.""" - # Test with 1D data (should raise error) - with self.assertRaises(ValueError): - asr_calibrate(np.random.randn(100), self.srate) - - # Test with 3D data (should raise error) - with self.assertRaises(ValueError): - asr_calibrate(np.random.randn(4, 100, 10), self.srate) - - # Test with too little data - short_data = np.random.randn(self.n_channels, 50) - with self.assertRaises(ValueError): - asr_calibrate(short_data, self.srate) - - def test_custom_parameters(self): - """Test calibration with custom parameters.""" - state = asr_calibrate( - self.clean_data, - self.srate, - cutoff=3.0, - blocksize=20, - window_len=0.25, - window_overlap=0.5, - max_dropout_fraction=0.2, - min_clean_fraction=0.3, - maxmem=32, - ) - - # Should complete without errors - self.assertIsInstance(state, dict) - self.assertIn('M', state) - self.assertIn('T', state) - - def test_custom_filter_coefficients(self): - """Test calibration with custom filter coefficients.""" - # Simple custom filter - B = np.array([1.0, -0.5]) - A = np.array([1.0, -0.3]) - - state = asr_calibrate(self.clean_data, self.srate, B=B, A=A) - - # Check that custom coefficients are stored - np.testing.assert_array_equal(state['B'], B) - np.testing.assert_array_equal(state['A'], A) - - def test_riemannian_calibration(self): - """Test Riemannian ASR calibration variant.""" - # Mock the cov_mean function to avoid complex dependencies - with patch('eegprep.plugins.clean_rawdata.asr_calibrate.cov_mean') as mock_cov_mean: - mock_cov_mean.return_value = np.eye(self.n_channels) * 0.5 - - asr_calibrate(self.clean_data, self.srate, useriemannian='calib') - - # Should have called cov_mean with robust=True - mock_cov_mean.assert_called_once() - call_args = mock_cov_mean.call_args - self.assertTrue(call_args[1]['robust']) - - def test_nan_handling(self): - """Test handling of NaN values in input data.""" - data_with_nans = self.clean_data.copy() - data_with_nans[2, 100:110] = np.nan - data_with_nans[5, 500] = np.inf - - # Should not raise error - NaNs should be replaced with zeros - state = asr_calibrate(data_with_nans, self.srate) - - self.assertIsInstance(state, dict) - self.assertTrue(np.all(np.isfinite(state['M']))) - self.assertTrue(np.all(np.isfinite(state['T']))) - - def test_filter_divergence_error(self): - """Test error handling when IIR filter diverges.""" - # Mock scipy.signal.sosfilt to return NaN values (simulating filter divergence) - with patch('scipy.signal.sosfilt') as mock_sosfilt: - # Return data with NaN values to simulate filter divergence - mock_sosfilt.return_value = ( - np.full((4, 1000), np.nan), - np.zeros((2, 4, 2)), # Mock iir_state - ) - - with self.assertRaises(RuntimeError) as cm: - asr_calibrate(self.clean_data, self.srate) - - self.assertIn('IIR filter diverged', str(cm.exception)) - - def test_threshold_calculation_robustness(self): - """Test robustness of threshold calculation with edge cases.""" - # Create data with some extreme values - data = self.clean_data.copy() - data[:, :100] *= 10 # Add some "artifacts" - - with patch('eegprep.plugins.clean_rawdata.asr_calibrate.fit_eeg_distribution') as mock_fit: - # Mock successful fitting for most components - mock_fit.return_value = (1.0, 0.5, None, None) - - state = asr_calibrate(data, self.srate) - - # Should complete successfully - self.assertIsInstance(state, dict) - self.assertTrue(np.all(np.isfinite(state['T']))) - - def test_blocksize_calculation(self): - """Test automatic blocksize calculation based on memory constraints.""" - # Test with very low memory limit - state = asr_calibrate(self.clean_data, self.srate, maxmem=1) # 1 MB - - # Should still work, just with larger blocksize - self.assertIsInstance(state, dict) - - def test_geometric_median_fallback(self): - """Test fallback to geometric median when Riemannian method fails.""" - with patch('eegprep.plugins.clean_rawdata.asr_calibrate.cov_mean') as mock_cov_mean: - # Make cov_mean return NaNs to trigger fallback - mock_cov_mean.return_value = np.full((self.n_channels, self.n_channels), np.nan) - - with patch('eegprep.plugins.clean_rawdata.asr_calibrate.geometric_median') as mock_geom_median: - mock_geom_median.return_value = np.eye(self.n_channels).flatten() - - with self.assertLogs('eegprep.plugins.clean_rawdata.asr_calibrate', level='WARNING') as log: - asr_calibrate(self.clean_data, self.srate, useriemannian='calib') - - # Check that warning was logged and fallback was used - self.assertTrue(any('NaNs' in msg for msg in log.output)) - mock_geom_median.assert_called_once() - - -class TestAsrProcess(unittest.TestCase): - """Test the asr_process function.""" - - def setUp(self): - """Set up test fixtures.""" - np.random.seed(123) - self.n_channels = 6 - self.n_samples = 500 - self.srate = 200.0 - - # Create calibration data and state - calib_data = np.random.randn(self.n_channels, 1000) * 0.3 - self.state = asr_calibrate(calib_data, self.srate) - - # Create test data with some artifacts - self.test_data = np.random.randn(self.n_channels, self.n_samples) * 0.4 - # Add some artifacts to specific channels/times - self.test_data[2, 100:150] += np.random.randn(50) * 2.0 # Large artifacts - - def test_basic_processing(self): - """Test basic ASR processing functionality.""" - cleaned_data, new_state = asr_process(self.test_data, self.srate, self.state) - - # Check output shapes - self.assertEqual(cleaned_data.shape, self.test_data.shape) - - # Check that state is updated - self.assertIsInstance(new_state, dict) - self.assertIn('M', new_state) - self.assertIn('T', new_state) - self.assertIn('carry', new_state) - self.assertIn('cov', new_state) - - # Check that cleaned data is finite - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - def test_empty_data_handling(self): - """Test processing with empty data.""" - empty_data = np.empty((self.n_channels, 0)) - cleaned_data, new_state = asr_process(empty_data, self.srate, self.state) - - # Should return empty data unchanged - self.assertEqual(cleaned_data.shape, (self.n_channels, 0)) - self.assertEqual(new_state['M'].shape, self.state['M'].shape) - - def test_parameter_validation(self): - """Test parameter validation.""" - # Test with 1D data - with self.assertRaises(ValueError): - asr_process(np.random.randn(100), self.srate, self.state) - - # Test with 3D data - with self.assertRaises(ValueError): - asr_process(np.random.randn(4, 100, 10), self.srate, self.state) - - def test_custom_processing_parameters(self): - """Test processing with custom parameters.""" - cleaned_data, new_state = asr_process( - self.test_data, self.srate, self.state, window_len=0.25, lookahead=0.1, step_size=16, max_dims=0.5 - ) - - # Should complete without errors - self.assertEqual(cleaned_data.shape, self.test_data.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - def test_max_dims_as_integer(self): - """Test processing with max_dims specified as integer.""" - cleaned_data, new_state = asr_process( - self.test_data, - self.srate, - self.state, - max_dims=3, # Integer instead of fraction - ) - - self.assertEqual(cleaned_data.shape, self.test_data.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - def test_nan_handling_in_processing(self): - """Test handling of NaN values during processing.""" - data_with_nans = self.test_data.copy() - data_with_nans[1, 50:60] = np.nan - data_with_nans[3, 200] = np.inf - - cleaned_data, new_state = asr_process(data_with_nans, self.srate, self.state) - - # Should handle NaNs gracefully - self.assertEqual(cleaned_data.shape, self.test_data.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - def test_memory_management(self): - """Test memory management with large data chunks.""" - # Create larger data that might trigger splitting - large_data = np.random.randn(self.n_channels, 5000) * 0.5 - - with patch('psutil.virtual_memory') as mock_vm: - # Mock low available memory to trigger splitting - mock_vm.return_value.free = 50 * 1024**2 # 50 MB - - with self.assertLogs('eegprep.plugins.clean_rawdata.asr_process', level='INFO') as log: - cleaned_data, new_state = asr_process( - large_data, - self.srate, - self.state, - max_mem=10, # Low memory limit - ) - - # Check that splitting was logged - self.assertTrue(any('blocks' in msg for msg in log.output)) - - # Check output - self.assertEqual(cleaned_data.shape, large_data.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - def test_memory_error_handling(self): - """Test error handling when memory is insufficient.""" - with patch('psutil.virtual_memory') as mock_vm: - # Mock extremely low memory - mock_vm.return_value.free = 1024 # 1 KB - - with self.assertRaises(RuntimeError) as cm: - asr_process(self.test_data, self.srate, self.state, max_mem=0.001) - - self.assertIn('Not enough memory', str(cm.exception)) - - def test_rank_deficient_covariance_produces_sane_output(self): - """Process genuinely rank-deficient data (singular covariance). - - Duplicate and zeroed channels make the per-window covariance singular, - exercising the eigendecomposition and pseudo-inverse paths with a real - degenerate input rather than monkeypatching numpy to raise. The cleaned - output must stay finite and keep its shape. - """ - degenerate = self.test_data.copy() - degenerate[3, :] = degenerate[0, :] # duplicate channel -> singular covariance - degenerate[5, :] = 0.0 # flat channel -> singular covariance - - cleaned_data, new_state = asr_process(degenerate, self.srate, self.state) - - self.assertEqual(cleaned_data.shape, degenerate.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - def test_extreme_artifact_amplitudes_produce_sane_output(self): - """Process data with extreme-amplitude artifacts. - - Huge transient amplitudes drive the reconstruction matrix toward - ill-conditioning, exercising the same numeric path. The cleaned output - must remain finite, keep its shape, and attenuate the injected spike. - """ - extreme = self.test_data.copy() - spike_peak = float(np.max(np.abs(extreme))) * 1e4 - extreme[2, 100:150] += spike_peak - - cleaned_data, new_state = asr_process(extreme, self.srate, self.state) - - self.assertEqual(cleaned_data.shape, extreme.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - self.assertLess(float(np.max(np.abs(cleaned_data))), spike_peak) - - def test_component_selection_shape_error_propagates(self): - """A shape/contract bug during component selection must surface, not be - silently swallowed into a no-op (keep-all) for the affected window. - """ - bad_state = dict(self.state) - # T is C x C in a valid state; an incompatible threshold matrix makes - # finite_matmul(T, V) fail. This must raise rather than disable cleaning. - bad_state['T'] = np.ones((self.n_channels + 1, self.n_channels + 1)) - - with self.assertRaises(ValueError): - asr_process(self.test_data, self.srate, bad_state) - - def test_state_persistence_across_calls(self): - """Test that state is properly maintained across multiple processing calls.""" - # First call - chunk1 = self.test_data[:, :250] - cleaned1, state1 = asr_process(chunk1, self.srate, self.state) - - # Second call with updated state - chunk2 = self.test_data[:, 250:] - cleaned2, state2 = asr_process(chunk2, self.srate, state1) - - # Check that carry buffer was maintained - self.assertIsNotNone(state1['carry']) - self.assertIsNotNone(state2['carry']) - - # Check output shapes - self.assertEqual(cleaned1.shape, chunk1.shape) - self.assertEqual(cleaned2.shape, chunk2.shape) - - def test_window_length_adjustment(self): - """Test automatic window length adjustment for small datasets.""" - # Create data that would require window length adjustment - small_data = np.random.randn(self.n_channels, 50) * 0.5 - - cleaned_data, new_state = asr_process( - small_data, - self.srate, - self.state, - window_len=0.1, # Very small window - ) - - # Should complete without errors despite small data - self.assertEqual(cleaned_data.shape, small_data.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - def test_component_selection_error_handling(self): - """An unexpected error during threshold computation must propagate. - - Previously such errors were swallowed and the window silently kept all - components (no artifact removal). They must now surface so genuine bugs are - visible rather than producing quietly under-cleaned data. - """ - with patch('numpy.sum', side_effect=Exception("Threshold error")): - with self.assertRaises(Exception) as cm: - asr_process(self.test_data, self.srate, self.state) - self.assertIn('Threshold error', str(cm.exception)) - - -class TestAsrIntegration(unittest.TestCase): - """Integration tests for ASR calibration and processing.""" - - def setUp(self): - """Set up integration test fixtures.""" - np.random.seed(456) - self.n_channels = 4 - self.srate = 128.0 - - # Create realistic calibration data - self.calib_data = self.create_realistic_eeg(self.n_channels, int(self.srate * 60)) - - # Create test data with artifacts - self.test_data = self.create_realistic_eeg(self.n_channels, int(self.srate * 10)) - self.add_artifacts(self.test_data) - - def create_realistic_eeg(self, n_channels, n_samples): - """Create more realistic EEG-like data.""" - data = np.random.randn(n_channels, n_samples) * 0.2 - - # Add some correlated structure between channels - for i in range(n_channels): - for j in range(i + 1, n_channels): - if np.random.rand() > 0.7: # 30% chance of correlation - correlation = 0.3 * np.random.randn() - data[j] += correlation * data[i] - - # Add some temporal autocorrelation - for i in range(n_channels): - for j in range(1, n_samples): - data[i, j] += 0.7 * data[i, j - 1] * np.random.rand() - - return data - - def add_artifacts(self, data): - """Add realistic artifacts to EEG data.""" - n_channels, n_samples = data.shape - - # Add muscle artifacts (high frequency, high amplitude) - artifact_start = n_samples // 4 - artifact_end = artifact_start + n_samples // 10 - data[1, artifact_start:artifact_end] += np.random.randn(artifact_end - artifact_start) * 3.0 - - # Add eye blink artifacts (affects frontal channels) - blink_times = [n_samples // 2, 3 * n_samples // 4] - for blink_time in blink_times: - blink_duration = 20 - if blink_time + blink_duration < n_samples: - blink_artifact = 5.0 * np.exp(-np.arange(blink_duration) / 5.0) - data[0, blink_time : blink_time + blink_duration] += blink_artifact - - def test_full_calibration_and_processing_pipeline(self): - """Test complete ASR pipeline from calibration to processing.""" - # Calibrate ASR - state = asr_calibrate(self.calib_data, self.srate) - - # Process test data - cleaned_data, final_state = asr_process(self.test_data, self.srate, state) - - # Basic checks - self.assertEqual(cleaned_data.shape, self.test_data.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - # Check that artifacts were reduced (RMS should be lower in artifact regions) - artifact_region = slice( - self.test_data.shape[1] // 4, self.test_data.shape[1] // 4 + self.test_data.shape[1] // 10 - ) - - original_rms = np.sqrt(np.mean(self.test_data[1, artifact_region] ** 2)) - cleaned_rms = np.sqrt(np.mean(cleaned_data[1, artifact_region] ** 2)) - - # Cleaned data should have lower RMS in artifact region - self.assertLess(cleaned_rms, original_rms) - - def test_streaming_processing_simulation(self): - """Test ASR processing in streaming mode with multiple chunks.""" - # Calibrate - state = asr_calibrate(self.calib_data, self.srate) - - # Process in chunks to simulate streaming - chunk_size = int(self.srate * 2) # 2-second chunks - n_chunks = self.test_data.shape[1] // chunk_size - - cleaned_chunks = [] - current_state = state - - for i in range(n_chunks): - start_idx = i * chunk_size - end_idx = min((i + 1) * chunk_size, self.test_data.shape[1]) - chunk = self.test_data[:, start_idx:end_idx] - - cleaned_chunk, current_state = asr_process(chunk, self.srate, current_state) - cleaned_chunks.append(cleaned_chunk) - - # Concatenate results - full_cleaned = np.concatenate(cleaned_chunks, axis=1) - - # Check results - expected_length = n_chunks * chunk_size - self.assertEqual(full_cleaned.shape[1], expected_length) - self.assertTrue(np.all(np.isfinite(full_cleaned))) - - def test_different_calibration_and_processing_parameters(self): - """Test ASR with various parameter combinations.""" - parameter_sets = [ - {'cutoff': 3.0, 'window_len': 0.25, 'max_dims': 0.5}, - {'cutoff': 7.0, 'window_len': 1.0, 'max_dims': 2}, - {'cutoff': 4.0, 'blocksize': 20, 'window_overlap': 0.8}, - ] - - for params in parameter_sets: - with self.subTest(params=params): - # Split parameters between calibration and processing - calib_params = {k: v for k, v in params.items() if k in ['cutoff', 'blocksize', 'window_overlap']} - process_params = {k: v for k, v in params.items() if k in ['window_len', 'max_dims']} - - # Test pipeline - state = asr_calibrate(self.calib_data, self.srate, **calib_params) - cleaned_data, _ = asr_process(self.test_data, self.srate, state, **process_params) - - # Should complete without errors - self.assertEqual(cleaned_data.shape, self.test_data.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - def test_robustness_with_challenging_data(self): - """Test ASR robustness with challenging data conditions.""" - # Test with very noisy calibration data - noisy_calib = self.calib_data + np.random.randn(*self.calib_data.shape) * 0.5 - - # Should still calibrate successfully - state = asr_calibrate(noisy_calib, self.srate) - cleaned_data, _ = asr_process(self.test_data, self.srate, state) - - self.assertEqual(cleaned_data.shape, self.test_data.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - # Test with data containing extreme outliers - outlier_data = self.test_data.copy() - outlier_data[2, 100:105] = 1000.0 # Extreme outliers - - cleaned_data, _ = asr_process(outlier_data, self.srate, state) - - # Should handle outliers gracefully - self.assertEqual(cleaned_data.shape, outlier_data.shape) - self.assertTrue(np.all(np.isfinite(cleaned_data))) - - # Extreme values should be reduced - max_cleaned = np.max(np.abs(cleaned_data[2, 100:105])) - max_original = np.max(np.abs(outlier_data[2, 100:105])) - self.assertLess(max_cleaned, max_original) - - -class TestAsrEdgeCases(unittest.TestCase): - """Test edge cases and error conditions for ASR functions.""" - - def test_minimum_viable_data_sizes(self): - """Test ASR with minimum viable data sizes.""" - n_channels = 2 - srate = 100.0 - - # Minimum calibration data (need enough for multiple windows) - # window_len=0.5s, overlap=0.66 means step=0.17s, need at least 2 windows - min_samples = int(srate * 1.2) # 1.2 seconds to allow multiple windows - min_data = np.random.randn(n_channels, min_samples) * 0.3 - - state = asr_calibrate(min_data, srate, window_len=0.5) - - # Should work with minimal data - self.assertIsInstance(state, dict) - self.assertIn('M', state) - - # Test processing with minimal data - test_data = np.random.randn(n_channels, 20) * 0.4 - cleaned_data, _ = asr_process(test_data, srate, state) - - self.assertEqual(cleaned_data.shape, test_data.shape) - - def test_single_channel_data(self): - """Test ASR with single channel data.""" - n_channels = 1 - srate = 200.0 - calib_data = np.random.randn(n_channels, 1000) * 0.3 - - state = asr_calibrate(calib_data, srate) - - # Should work with single channel - self.assertEqual(state['M'].shape, (1, 1)) - self.assertEqual(state['T'].shape, (1, 1)) - - # Test processing - test_data = np.random.randn(n_channels, 100) * 0.4 - cleaned_data, _ = asr_process(test_data, srate, state) - - self.assertEqual(cleaned_data.shape, test_data.shape) - - def test_very_high_sampling_rate(self): - """Very high (unsupported) sampling rate must fail loudly in calibration. - - 2000 Hz has no pre-computed spectral filter; calibration must raise rather - than silently substitute a degenerate difference filter. - """ - n_channels = 4 - srate = 2000.0 # High, unsupported sampling rate - - n_samples = int(srate * 2) # 2 seconds - calib_data = np.random.randn(n_channels, n_samples) * 0.3 - - with self.assertRaises(ValueError) as cm: - asr_calibrate(calib_data, srate) - self.assertIn('No pre-computed ASR spectral filter', str(cm.exception)) - - # With explicit filter coefficients the high rate is processable end-to-end. - B = np.array([1.0, -0.5]) - A = np.array([1.0]) - state = asr_calibrate(calib_data, srate, B=B, A=A) - test_data = np.random.randn(n_channels, 200) * 0.4 - cleaned_data, _ = asr_process(test_data, srate, state) - self.assertEqual(cleaned_data.shape, test_data.shape) - - def test_zero_variance_data(self): - """Test ASR with zero variance data.""" - n_channels = 3 - srate = 250.0 - - # Create data with zero variance in some channels - calib_data = np.random.randn(n_channels, 1000) * 0.3 - calib_data[1, :] = 1.0 # Constant channel - - # Should handle zero variance gracefully - state = asr_calibrate(calib_data, srate) - - self.assertIsInstance(state, dict) - self.assertTrue(np.all(np.isfinite(state['M']))) - - def test_memory_usage_calculation_accuracy(self): - """Test that memory usage calculations are reasonable.""" - n_channels = 8 - srate = 250.0 - n_samples = 10000 # Large dataset - - data = np.random.randn(n_channels, n_samples) * 0.3 - - # Test with different memory limits - for max_mem in [1, 10, 100]: # MB - with self.subTest(max_mem=max_mem): - state = asr_calibrate(data, srate, maxmem=max_mem) - - # Should complete regardless of memory limit - self.assertIsInstance(state, dict) - - # Test processing with memory limits - cleaned_data, _ = asr_process(data[:, :1000], srate, state, max_mem=max_mem) - self.assertEqual(cleaned_data.shape, (n_channels, 1000)) - - -class TestCleanAsrNoMutation(unittest.TestCase): - """Regression tests that clean_asr never mutates the caller's EEG.""" - - def setUp(self): - np.random.seed(7) - n_channels = 8 - n_samples = 2500 - srate = 250.0 - data = np.random.randn(n_channels, n_samples) * 0.5 - for i in range(n_channels): - for j in range(1, n_samples): - data[i, j] += 0.8 * data[i, j - 1] - # Inject a non-finite sample to exercise the in-place NaN-zeroing path - # that asr_calibrate applies to whatever array it receives. - data[0, 100] = np.nan - self.EEG = { - 'data': data, - 'srate': srate, - 'nbchan': n_channels, - 'pnts': n_samples, - 'etc': {}, - } - - def test_does_not_mutate_input_data(self): - """clean_asr must leave the caller's EEG['data'] (incl. NaNs) unchanged.""" - EEG_in = self.EEG - original_data = EEG_in['data'].copy() - - EEG_out = clean_asr(EEG_in, ref_maxbadchannels='off') - - # The caller's data is byte-for-byte unchanged, including the NaN that - # asr_calibrate would otherwise have zeroed in place. - self.assertTrue(np.array_equal(original_data, EEG_in['data'], equal_nan=True)) - # Output is a distinct object with distinct data. - self.assertIsNot(EEG_out, EEG_in) - self.assertIsNot(EEG_out['data'], EEG_in['data']) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_utils_covariance.py b/tests/test_utils_covariance.py deleted file mode 100644 index ce397cd3..00000000 --- a/tests/test_utils_covariance.py +++ /dev/null @@ -1,462 +0,0 @@ -import unittest -import numpy as np -from unittest.mock import patch - -from eegprep.plugins.clean_rawdata.private.covariance import ( - diag_nd, - cov_logm, - cov_expm, - cov_powm, - cov_sqrtm, - cov_rsqrtm, - cov_sqrtm2, - cov_mean, - cov_shrinkage, -) - - -class TestDiagNd(unittest.TestCase): - """Test the diag_nd utility function.""" - - def test_1d_input(self): - """Test diag_nd with 1D input (like np.diag).""" - x = np.array([1, 2, 3]) - result = diag_nd(x) - expected = np.diag(x) - np.testing.assert_array_equal(result, expected) - - def test_2d_input(self): - """Test diag_nd with 2D input (multiple diagonal matrices).""" - x = np.array([[1, 2], [3, 4]]) # Shape: (2, 2) - result = diag_nd(x) - - # Should create 2 diagonal matrices - expected = np.array( - [ - [[1, 0], [0, 2]], # diag([1, 2]) - [[3, 0], [0, 4]], # diag([3, 4]) - ] - ) - np.testing.assert_array_equal(result, expected) - - def test_3d_input(self): - """Test diag_nd with 3D input.""" - x = np.array([[[1, 2, 3]], [[4, 5, 6]]]) # Shape: (2, 1, 3) - result = diag_nd(x) - - # Should have shape (2, 1, 3, 3) - self.assertEqual(result.shape, (2, 1, 3, 3)) - - # Check first diagonal matrix - expected_first = np.diag([1, 2, 3]) - np.testing.assert_array_equal(result[0, 0], expected_first) - - -class TestCovarianceMatrixOperations(unittest.TestCase): - """Test matrix operations on covariance matrices.""" - - def setUp(self): - """Create test covariance matrices.""" - # Simple 2x2 positive definite matrix - self.cov_2x2 = np.array([[2.0, 1.0], [1.0, 2.0]]) - - # 3x3 positive definite matrix - self.cov_3x3 = np.array([[4.0, 1.0, 0.5], [1.0, 3.0, 1.0], [0.5, 1.0, 2.0]]) - - # Stack of covariance matrices - self.cov_stack = np.array([[[2.0, 1.0], [1.0, 2.0]], [[3.0, 0.5], [0.5, 1.5]]]) - - def test_cov_logm_single_matrix(self): - """Test matrix logarithm of single covariance matrix.""" - result = cov_logm(self.cov_2x2) - - # Verify result is symmetric - np.testing.assert_array_almost_equal(result, result.T, decimal=10) - - # Verify expm(logm(C)) = C - reconstructed = cov_expm(result) - np.testing.assert_array_almost_equal(reconstructed, self.cov_2x2, decimal=10) - - def test_cov_expm_single_matrix(self): - """Test matrix exponential of single matrix.""" - # Start with log of a matrix, then exponentiate - log_cov = cov_logm(self.cov_2x2) - result = cov_expm(log_cov) - - # Should recover original matrix - np.testing.assert_array_almost_equal(result, self.cov_2x2, decimal=10) - - # Result should be positive definite (all eigenvalues > 0) - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - def test_cov_powm_single_matrix(self): - """Test matrix power operation.""" - # Test square root (power = 0.5) - sqrt_result = cov_powm(self.cov_2x2, 0.5) - - # sqrt(C) @ sqrt(C) should equal C - reconstructed = sqrt_result @ sqrt_result - np.testing.assert_array_almost_equal(reconstructed, self.cov_2x2, decimal=10) - - # Test square (power = 2) - square_result = cov_powm(self.cov_2x2, 2.0) - expected = self.cov_2x2 @ self.cov_2x2 - np.testing.assert_array_almost_equal(square_result, expected, decimal=10) - - def test_cov_sqrtm_single_matrix(self): - """Test matrix square root.""" - result = cov_sqrtm(self.cov_2x2) - - # sqrt(C) @ sqrt(C) should equal C - reconstructed = result @ result - np.testing.assert_array_almost_equal(reconstructed, self.cov_2x2, decimal=10) - - # Result should be symmetric and positive definite - np.testing.assert_array_almost_equal(result, result.T, decimal=10) - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - def test_cov_rsqrtm_single_matrix(self): - """Test matrix reciprocal square root.""" - result = cov_rsqrtm(self.cov_2x2) - - # rsqrt(C) @ C @ rsqrt(C) should equal identity - whitened = result @ self.cov_2x2 @ result - np.testing.assert_array_almost_equal(whitened, np.eye(2), decimal=10) - - # Result should be symmetric and positive definite - np.testing.assert_array_almost_equal(result, result.T, decimal=10) - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - def test_cov_sqrtm2_single_matrix(self): - """Test combined square root and reciprocal square root.""" - sqrt_result, rsqrt_result = cov_sqrtm2(self.cov_2x2) - - # Compare with individual functions - expected_sqrt = cov_sqrtm(self.cov_2x2) - expected_rsqrt = cov_rsqrtm(self.cov_2x2) - - np.testing.assert_array_almost_equal(sqrt_result, expected_sqrt, decimal=10) - np.testing.assert_array_almost_equal(rsqrt_result, expected_rsqrt, decimal=10) - - # Verify relationship: sqrt @ rsqrt = identity - identity_check = sqrt_result @ rsqrt_result - np.testing.assert_array_almost_equal(identity_check, np.eye(2), decimal=10) - - def test_stack_operations(self): - """Test operations on stacks of covariance matrices.""" - # Test all operations work with stacks - log_stack = cov_logm(self.cov_stack) - exp_stack = cov_expm(log_stack) - sqrt_stack = cov_sqrtm(self.cov_stack) - rsqrt_stack = cov_rsqrtm(self.cov_stack) - pow_stack = cov_powm(self.cov_stack, 0.5) - sqrt2_stack, rsqrt2_stack = cov_sqrtm2(self.cov_stack) - - # Check shapes - self.assertEqual(log_stack.shape, self.cov_stack.shape) - self.assertEqual(exp_stack.shape, self.cov_stack.shape) - self.assertEqual(sqrt_stack.shape, self.cov_stack.shape) - self.assertEqual(rsqrt_stack.shape, self.cov_stack.shape) - - # Check round-trip: exp(log(C)) = C - np.testing.assert_array_almost_equal(exp_stack, self.cov_stack, decimal=10) - - # Check sqrt consistency - np.testing.assert_array_almost_equal(sqrt_stack, pow_stack, decimal=10) - np.testing.assert_array_almost_equal(sqrt_stack, sqrt2_stack, decimal=10) - np.testing.assert_array_almost_equal(rsqrt_stack, rsqrt2_stack, decimal=10) - - -class TestCovMean(unittest.TestCase): - """Test the covariance mean function.""" - - def setUp(self): - """Create test data.""" - # Create a stack of similar covariance matrices - self.cov_stack = np.array([[[2.0, 0.5], [0.5, 1.5]], [[2.2, 0.3], [0.3, 1.8]], [[1.8, 0.7], [0.7, 1.2]]]) - - # Single matrix (should return itself) - self.single_cov = np.array([[[3.0, 1.0], [1.0, 2.0]]]) - - def test_single_matrix_mean(self): - """Test mean of single matrix returns the matrix itself.""" - result = cov_mean(self.single_cov) - np.testing.assert_array_almost_equal(result, self.single_cov[0], decimal=10) - - def test_unweighted_mean(self): - """Test unweighted mean of covariance matrices.""" - result = cov_mean(self.cov_stack) - - # Result should be symmetric and positive definite - np.testing.assert_array_almost_equal(result, result.T, decimal=10) - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - # Should be roughly in the middle of the input matrices - expected_trace = np.mean([np.trace(cov) for cov in self.cov_stack]) - actual_trace = np.trace(result) - self.assertAlmostEqual(actual_trace, expected_trace, delta=0.5) - - def test_weighted_mean(self): - """Test weighted mean of covariance matrices.""" - weights = np.array([0.5, 0.3, 0.2]) - result = cov_mean(self.cov_stack, weights=weights) - - # Result should be symmetric and positive definite - np.testing.assert_array_almost_equal(result, result.T, decimal=10) - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - # Should be closer to the first matrix (highest weight) - dist_to_first = np.linalg.norm(result - self.cov_stack[0]) - dist_to_last = np.linalg.norm(result - self.cov_stack[2]) - self.assertLess(dist_to_first, dist_to_last) - - def test_robust_mean_geometric_median(self): - """Test robust mean with geometric median (huber=0).""" - result = cov_mean(self.cov_stack, robust=True, huber=0) - - # Result should be symmetric and positive definite - np.testing.assert_array_almost_equal(result, result.T, decimal=10) - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - def test_robust_mean_huber(self): - """Test robust mean with Huber estimator.""" - result = cov_mean(self.cov_stack, robust=True, huber=1.0) - - # Result should be symmetric and positive definite - np.testing.assert_array_almost_equal(result, result.T, decimal=10) - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - def test_convergence_parameters(self): - """Test convergence parameters (iterations and tolerance).""" - # Very loose tolerance should converge quickly - result1 = cov_mean(self.cov_stack, tol=1e-1, iters=5) - - # Tight tolerance - result2 = cov_mean(self.cov_stack, tol=1e-10, iters=100) - - # Both should be valid covariance matrices - for result in [result1, result2]: - np.testing.assert_array_almost_equal(result, result.T, decimal=10) - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - def test_nancheck_functionality(self): - """Test NaN checking functionality.""" - # Create data that might cause numerical issues - problematic_cov = np.array( - [ - [[1e10, 0], [0, 1e-10]], # Very different scales - [[1e-10, 0], [0, 1e10]], - ] - ) - - # Should not raise with nancheck=False (default) - result = cov_mean(problematic_cov, nancheck=False) - self.assertFalse(np.any(np.isnan(result))) - - # Test that nancheck=True would catch NaNs if they occurred - # (We can't easily create a case that reliably produces NaNs) - with patch('numpy.any') as mock_any: - mock_any.return_value = True # Simulate NaNs detected - with self.assertRaises(RuntimeError): - cov_mean(self.cov_stack, nancheck=True) - - def test_verbose_mode(self): - """Test verbose mode output.""" - with patch('eegprep.plugins.clean_rawdata.private.covariance.logger.info') as mock_log: - cov_mean(self.cov_stack, robust=True, huber=None, verbose=True) - # Should have logged median deviations - mock_log.assert_called() - - -class TestCovShrinkage(unittest.TestCase): - """Test the covariance shrinkage function.""" - - def setUp(self): - """Create test covariance matrices.""" - self.cov_2x2 = np.array([[4.0, 2.0], [2.0, 3.0]]) - self.cov_3x3 = np.array([[5.0, 1.0, 0.5], [1.0, 4.0, 1.5], [0.5, 1.5, 3.0]]) - - # Stack of matrices - self.cov_stack = np.array([[[4.0, 2.0], [2.0, 3.0]], [[6.0, 1.0], [1.0, 2.0]]]) - - def test_no_shrinkage(self): - """Test that zero shrinkage returns original matrix.""" - result = cov_shrinkage(self.cov_2x2, shrinkage=0) - np.testing.assert_array_equal(result, self.cov_2x2) - - # Test with different targets (should all return original) - for target in ['eye', 'scaled-eye', 'diag']: - result = cov_shrinkage(self.cov_2x2, shrinkage=0, target=target) - np.testing.assert_array_equal(result, self.cov_2x2) - - def test_full_shrinkage_eye(self): - """Test full shrinkage towards identity.""" - result = cov_shrinkage(self.cov_2x2, shrinkage=1.0, target='eye') - expected = np.eye(2) - np.testing.assert_array_equal(result, expected) - - def test_full_shrinkage_scaled_eye(self): - """Test full shrinkage towards scaled identity.""" - result = cov_shrinkage(self.cov_2x2, shrinkage=1.0, target='scaled-eye') - - # Should be scaled identity with scale = trace/N - trace = np.trace(self.cov_2x2) - scale = trace / 2 - expected = scale * np.eye(2) - np.testing.assert_array_almost_equal(result, expected) - - def test_full_shrinkage_diag(self): - """Test full shrinkage towards diagonal.""" - result = cov_shrinkage(self.cov_2x2, shrinkage=1.0, target='diag') - expected = np.diag(np.diag(self.cov_2x2)) - np.testing.assert_array_equal(result, expected) - - def test_partial_shrinkage(self): - """Test partial shrinkage.""" - shrinkage = 0.3 - result = cov_shrinkage(self.cov_2x2, shrinkage=shrinkage, target='eye') - - # Should be weighted combination - expected = shrinkage * np.eye(2) + (1 - shrinkage) * self.cov_2x2 - np.testing.assert_array_almost_equal(result, expected) - - # Result should still be positive definite - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - def test_stack_shrinkage(self): - """Test shrinkage on stack of matrices.""" - shrinkage = 0.5 - result = cov_shrinkage(self.cov_stack, shrinkage=shrinkage, target='eye') - - # Check shape preserved - self.assertEqual(result.shape, self.cov_stack.shape) - - # Check each matrix individually - for i in range(len(self.cov_stack)): - expected = shrinkage * np.eye(2) + (1 - shrinkage) * self.cov_stack[i] - np.testing.assert_array_almost_equal(result[i], expected) - - def test_3d_matrices_shrinkage(self): - """Test shrinkage with 3D matrices.""" - result = cov_shrinkage(self.cov_3x3, shrinkage=0.4, target='diag') - - # Should preserve diagonal elements and shrink off-diagonal - np.testing.assert_array_almost_equal(np.diag(result), np.diag(self.cov_3x3)) - - # Off-diagonal should be shrunk - off_diag_orig = self.cov_3x3[0, 1] - off_diag_result = result[0, 1] - self.assertLess(abs(off_diag_result), abs(off_diag_orig)) - - def test_scaled_eye_with_stack(self): - """Test scaled-eye target with stack of matrices.""" - result = cov_shrinkage(self.cov_stack, shrinkage=1.0, target='scaled-eye') - - # Each matrix should be scaled identity - for i in range(len(self.cov_stack)): - trace = np.trace(self.cov_stack[i]) - scale = trace / 2 - expected = scale * np.eye(2) - np.testing.assert_array_almost_equal(result[i], expected) - - def test_invalid_target_error(self): - """Test error for invalid shrinkage target.""" - with self.assertRaises(ValueError) as cm: - cov_shrinkage(self.cov_2x2, shrinkage=0.5, target='invalid') - self.assertIn('Unsupported shrinkage target', str(cm.exception)) - - def test_symmetric_output(self): - """Test that output maintains symmetry.""" - for target in ['eye', 'scaled-eye', 'diag']: - result = cov_shrinkage(self.cov_2x2, shrinkage=0.3, target=target) - np.testing.assert_array_almost_equal(result, result.T, decimal=10) - - def test_positive_definite_preservation(self): - """Test that positive definiteness is preserved.""" - for target in ['eye', 'scaled-eye', 'diag']: - result = cov_shrinkage(self.cov_2x2, shrinkage=0.7, target=target) - eigenvals = np.linalg.eigvals(result) - self.assertTrue(np.all(eigenvals > 0)) - - -class TestEdgeCases(unittest.TestCase): - """Test edge cases and numerical stability.""" - - def test_near_singular_matrices(self): - """Test operations on near-singular matrices.""" - # Create a matrix with very small eigenvalues - near_singular = np.array([[1.0, 0.999], [0.999, 1.0]]) - - # All operations should still work - log_result = cov_logm(near_singular) - exp_result = cov_expm(log_result) - sqrt_result = cov_sqrtm(near_singular) - cov_rsqrtm(near_singular) - - # Check round-trip accuracy - np.testing.assert_array_almost_equal(exp_result, near_singular, decimal=8) - reconstructed = sqrt_result @ sqrt_result - np.testing.assert_array_almost_equal(reconstructed, near_singular, decimal=8) - - def test_identity_matrix_operations(self): - """Test operations on identity matrix.""" - identity = np.eye(3) - - # Log of identity should be zero matrix - log_result = cov_logm(identity) - np.testing.assert_array_almost_equal(log_result, np.zeros((3, 3)), decimal=10) - - # Square root of identity should be identity - sqrt_result = cov_sqrtm(identity) - np.testing.assert_array_almost_equal(sqrt_result, identity, decimal=10) - - # Reciprocal square root of identity should be identity - rsqrt_result = cov_rsqrtm(identity) - np.testing.assert_array_almost_equal(rsqrt_result, identity, decimal=10) - - def test_large_matrices(self): - """Test operations on larger matrices.""" - # Create a 10x10 positive definite matrix - np.random.seed(42) # For reproducibility - A = np.random.randn(10, 10) - large_cov = A @ A.T + np.eye(10) # Ensure positive definite - - # Test basic operations - sqrt_result = cov_sqrtm(large_cov) - reconstructed = sqrt_result @ sqrt_result - np.testing.assert_array_almost_equal(reconstructed, large_cov, decimal=8) - - # Test shrinkage - shrunk = cov_shrinkage(large_cov, shrinkage=0.1, target='eye') - eigenvals = np.linalg.eigvals(shrunk) - self.assertTrue(np.all(eigenvals > 0)) - - def test_numerical_precision_consistency(self): - """Test that operations maintain numerical precision.""" - # Use different dtypes - cov_float32 = self.cov_2x2.astype(np.float32) - cov_float64 = self.cov_2x2.astype(np.float64) - - # Results should be consistent (within precision limits) - sqrt32 = cov_sqrtm(cov_float32) - sqrt64 = cov_sqrtm(cov_float64) - - # Float32 should be close to float64 (within single precision) - np.testing.assert_array_almost_equal(sqrt32, sqrt64.astype(np.float32), decimal=6) - - def setUp(self): - """Set up test matrices.""" - self.cov_2x2 = np.array([[2.0, 1.0], [1.0, 2.0]]) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_utils_ransac.py b/tests/test_utils_ransac.py deleted file mode 100644 index 3e386ece..00000000 --- a/tests/test_utils_ransac.py +++ /dev/null @@ -1,509 +0,0 @@ -import unittest -import numpy as np -from unittest.mock import patch, MagicMock - -from eegprep.plugins.clean_rawdata.private.ransac import rand_sample, calc_projector -from eegprep.plugins.clean_rawdata.private.sphericalSplineInterpolate import sphericalSplineInterpolate - - -class TestRandSample(unittest.TestCase): - """Test the rand_sample function for random sampling without replacement.""" - - def setUp(self): - """Set up test fixtures.""" - self.rng = np.random.RandomState(42) # Fixed seed for reproducibility - - def test_basic_sampling(self): - """Test basic random sampling functionality.""" - n, m = 10, 5 - result = rand_sample(n, m, self.rng) - - # Check output shape and type - self.assertEqual(result.shape, (m,)) - self.assertEqual(result.dtype, int) - - # Check all values are within valid range - self.assertTrue(np.all(result >= 0)) - self.assertTrue(np.all(result < n)) - - # Check no duplicates (sampling without replacement) - self.assertEqual(len(np.unique(result)), m) - - def test_deterministic_with_fixed_seed(self): - """Test that results are deterministic with fixed random state.""" - n, m = 8, 4 - rng1 = np.random.RandomState(123) - rng2 = np.random.RandomState(123) - - result1 = rand_sample(n, m, rng1) - result2 = rand_sample(n, m, rng2) - - np.testing.assert_array_equal(result1, result2) - - def test_different_seeds_give_different_results(self): - """Test that different seeds produce different results.""" - n, m = 10, 5 - rng1 = np.random.RandomState(111) - rng2 = np.random.RandomState(222) - - result1 = rand_sample(n, m, rng1) - result2 = rand_sample(n, m, rng2) - - # Results should be different (very high probability) - self.assertFalse(np.array_equal(result1, result2)) - - def test_sample_all_elements(self): - """Test sampling all available elements.""" - n = 5 - m = n # Sample all elements - result = rand_sample(n, m, self.rng) - - # Should get all indices in some order - sorted_result = np.sort(result) - expected = np.arange(n) - np.testing.assert_array_equal(sorted_result, expected) - - def test_sample_one_element(self): - """Test sampling a single element.""" - n, m = 10, 1 - result = rand_sample(n, m, self.rng) - - self.assertEqual(result.shape, (1,)) - self.assertTrue(0 <= result[0] < n) - - def test_edge_case_small_pool(self): - """Test with very small pool size.""" - n, m = 2, 1 - result = rand_sample(n, m, self.rng) - - self.assertEqual(result.shape, (1,)) - self.assertIn(result[0], [0, 1]) - - def test_sampling_algorithm_coverage(self): - """Test that the sampling algorithm covers the pool properly.""" - n, m = 6, 3 - - # Run multiple times to check distribution - results = [] - for seed in range(100): - rng = np.random.RandomState(seed) - result = rand_sample(n, m, rng) - results.extend(result.tolist()) - - # Each index should appear at least once across many runs - unique_indices = set(results) - self.assertEqual(len(unique_indices), n) # All indices should appear - - -class TestCalcProjector(unittest.TestCase): - """Test the calc_projector function for RANSAC reconstruction matrices.""" - - def setUp(self): - """Set up test fixtures with synthetic channel locations.""" - # Create synthetic 3D channel locations (spherical coordinates) - self.n_channels = 8 - theta = np.linspace(0, 2 * np.pi, self.n_channels, endpoint=False) - phi = np.pi / 4 # Fixed elevation - - self.locs = np.column_stack( - [np.cos(theta) * np.cos(phi), np.sin(theta) * np.cos(phi), np.sin(phi) * np.ones(self.n_channels)] - ) - - # Test parameters - self.num_samples = 5 - self.subset_size = 4 - self.rng = np.random.RandomState(12345) - - def test_basic_projector_calculation(self): - """Test basic projector matrix calculation.""" - - # Mock the sphericalSplineInterpolate function - # Input: src_locs (3, subset_size), dest_locs (3, n_channels) - # Output: W (n_channels, subset_size) - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size (columns in transposed input) - n_dest = dest_locs.shape[1] # n_channels (columns in transposed input) - return np.random.randn(n_dest, n_src), None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ) as mock_interp: - result = calc_projector(self.locs, self.num_samples, self.subset_size, stream=self.rng) - - # Check output shape - expected_shape = (self.n_channels, self.n_channels * self.num_samples) - self.assertEqual(result.shape, expected_shape) - - # Verify interpolation function was called correct number of times - self.assertEqual(mock_interp.call_count, self.num_samples) - - def test_deterministic_with_fixed_stream(self): - """Test that results are deterministic with fixed random stream.""" - - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - return np.ones((n_dest, n_src)), None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ): - rng1 = np.random.RandomState(999) - rng2 = np.random.RandomState(999) - - result1 = calc_projector(self.locs, 3, 2, stream=rng1) - result2 = calc_projector(self.locs, 3, 2, stream=rng2) - - np.testing.assert_array_equal(result1, result2) - - def test_default_random_stream(self): - """Test that default random stream is used when none provided.""" - - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - # Return deterministic values for reproducible test - return np.ones((n_dest, n_src)), None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ): - # Call without providing stream - should use default seed - result1 = calc_projector(self.locs, 2, 2) - result2 = calc_projector(self.locs, 2, 2) - - # Should be identical due to fixed default seed - np.testing.assert_array_equal(result1, result2) - - def test_matlab_subroutine(self): - """Test using MATLAB subroutine.""" - mock_matlab = MagicMock() - - def mock_matlab_interp(src_locs, dest_locs): - n_src = src_locs.shape[1] # MATLAB uses transposed input - n_dest = dest_locs.shape[1] - return np.random.randn(n_dest, n_src), None - - mock_matlab.sphericalSplineInterpolate.side_effect = mock_matlab_interp - - with patch('eegprep.plugins.clean_rawdata.private.ransac.get_eeglab') as mock_get_eeglab: - mock_get_eeglab.return_value = mock_matlab - - result = calc_projector(self.locs, self.num_samples, self.subset_size, stream=self.rng, subroutine='matlab') - - # Check that MATLAB was requested and used - mock_get_eeglab.assert_called_once_with('MAT') - self.assertEqual(mock_matlab.sphericalSplineInterpolate.call_count, self.num_samples) - self.assertEqual(result.shape, (self.n_channels, self.n_channels * self.num_samples)) - - def test_octave_subroutine(self): - """Test using Octave subroutine.""" - mock_octave = MagicMock() - - def mock_octave_interp(src_locs, dest_locs): - n_src = src_locs.shape[1] # Octave uses transposed input - n_dest = dest_locs.shape[1] - return np.random.randn(n_dest, n_src), None - - mock_octave.sphericalSplineInterpolate.side_effect = mock_octave_interp - - with patch('eegprep.plugins.clean_rawdata.private.ransac.get_eeglab') as mock_get_eeglab: - mock_get_eeglab.return_value = mock_octave - - result = calc_projector(self.locs, self.num_samples, self.subset_size, stream=self.rng, subroutine='octave') - - # Check that Octave was requested and used - mock_get_eeglab.assert_called_once_with('OCT') - self.assertEqual(mock_octave.sphericalSplineInterpolate.call_count, self.num_samples) - self.assertEqual(result.shape, (self.n_channels, self.n_channels * self.num_samples)) - - def test_invalid_subroutine_error(self): - """Test error handling for invalid subroutine.""" - with self.assertRaises(ValueError) as cm: - calc_projector( - self.locs, self.num_samples, self.subset_size, stream=self.rng, subroutine='invalid_subroutine' - ) - - self.assertIn('Unknown subroutine: invalid_subroutine', str(cm.exception)) - - def test_different_sample_parameters(self): - """Test with different sampling parameters.""" - - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - # Return identity-like matrix truncated to correct size - result = np.zeros((n_dest, n_src)) - min_dim = min(n_dest, n_src) - result[:min_dim, :min_dim] = np.eye(min_dim) - return result, None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ): - # Test with different num_samples - result1 = calc_projector(self.locs, 3, 2, stream=self.rng) - result2 = calc_projector(self.locs, 6, 2, stream=self.rng) - - self.assertEqual(result1.shape, (self.n_channels, self.n_channels * 3)) - self.assertEqual(result2.shape, (self.n_channels, self.n_channels * 6)) - - # Test with different subset_size - result3 = calc_projector(self.locs, 4, 3, stream=self.rng) - self.assertEqual(result3.shape, (self.n_channels, self.n_channels * 4)) - - def test_complex_interpolation_result_handling(self): - """Test handling of complex interpolation results.""" - - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - # Create complex interpolation result - real_part = np.random.randn(n_dest, n_src) - imag_part = np.random.randn(n_dest, n_src) - return real_part + 1j * imag_part, None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ): - result = calc_projector(self.locs, 2, 2, stream=self.rng) - - # Result should be real (np.real applied) - self.assertTrue(np.isrealobj(result)) - self.assertFalse(np.iscomplexobj(result)) - - def test_sampling_coverage_across_channels(self): - """Test that sampling covers different channel subsets.""" - - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - result = np.zeros((n_dest, n_src)) - min_dim = min(n_dest, n_src) - result[:min_dim, :min_dim] = np.eye(min_dim) - return result, None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ) as mock_interp: - # Use many samples to ensure different subsets are chosen - result = calc_projector(self.locs, num_samples=20, subset_size=3, stream=np.random.RandomState(777)) - - # Check that result has expected shape - expected_shape = (self.n_channels, self.n_channels * 20) - self.assertEqual(result.shape, expected_shape) - - # Verify interpolation was called for each sample - self.assertEqual(mock_interp.call_count, 20) - - def test_edge_case_single_sample(self): - """Test with single sample.""" - - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - return np.random.randn(n_dest, n_src), None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ) as mock_interp: - result = calc_projector(self.locs, 1, 2, stream=self.rng) - - self.assertEqual(result.shape, (self.n_channels, self.n_channels)) - self.assertEqual(mock_interp.call_count, 1) - - def test_large_subset_size(self): - """Test with subset size close to total number of channels.""" - - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - return np.random.randn(n_dest, n_src), None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ) as mock_interp: - # Use subset_size = n_channels - 1 - result = calc_projector(self.locs, 3, self.n_channels - 1, stream=self.rng) - - self.assertEqual(result.shape, (self.n_channels, self.n_channels * 3)) - self.assertEqual(mock_interp.call_count, 3) - - def test_real_interpolation_channel_mapping_and_assembly(self): - """Validate calc_projector against the real spherical-spline kernel. - - The mock-based tests above only check output shape and call counts, so a - bug that permutes channel indices or mishandles the per-subset transpose - would pass. This test runs the real interpolation on a small montage and - independently reproduces each subset's reconstruction matrix, catching - such assembly bugs. - """ - num_samples = 4 - subset_size = self.n_channels - 2 - - # Reproduce the exact subsets calc_projector samples (k from num_samples-1..0). - subset_stream = np.random.RandomState(7) - subsets = {k: rand_sample(self.n_channels, subset_size, subset_stream) for k in range(num_samples - 1, -1, -1)} - - projector = calc_projector(self.locs, num_samples, subset_size, stream=np.random.RandomState(7)) - - # Output must be a finite, real-valued bag of reconstruction matrices. - self.assertEqual(projector.shape, (self.n_channels, self.n_channels * num_samples)) - self.assertTrue(np.isrealobj(projector)) - self.assertTrue(np.all(np.isfinite(projector))) - - blocks = projector.reshape(self.n_channels, num_samples, self.n_channels) - for k, sample in subsets.items(): - block = blocks[:, k, :] - - # Only the rows of the sampled source channels carry weight; the two - # unsampled channels must stay all-zero. A channel-index permutation - # would shift the zero rows away from the unsampled channels. - nonzero_rows = np.flatnonzero(np.any(block != 0, axis=1)) - np.testing.assert_array_equal(np.sort(nonzero_rows), np.sort(sample)) - - # The non-zero rows must equal the real spherical-spline weights for - # this subset, transposed exactly as calc_projector assembles them. - expected_w = sphericalSplineInterpolate(self.locs[sample, :].T, self.locs.T)[0] - np.testing.assert_allclose(block[sample, :], np.real(expected_w).T, rtol=1e-10, atol=1e-12) - - -class TestRansacIntegration(unittest.TestCase): - """Integration tests for RANSAC functionality.""" - - def setUp(self): - """Set up integration test fixtures.""" - # Create more realistic channel locations - self.n_channels = 16 - - # Create locations on unit sphere (typical EEG setup) - np.random.seed(42) - self.locs = np.random.randn(self.n_channels, 3) - # Normalize to unit sphere - norms = np.linalg.norm(self.locs, axis=1, keepdims=True) - self.locs = self.locs / norms - - def test_no_fail_path_with_realistic_data(self): - """Test that RANSAC functions don't fail with realistic data.""" - - # Mock the interpolation to avoid dependency on complex spatial functions - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - return np.random.randn(n_dest, n_src), None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ): - # This should not raise any exceptions - result = calc_projector(self.locs, num_samples=10, subset_size=8, stream=np.random.RandomState(555)) - - # Basic sanity checks - self.assertIsInstance(result, np.ndarray) - self.assertEqual(result.shape, (self.n_channels, self.n_channels * 10)) - self.assertFalse(np.any(np.isnan(result))) - self.assertFalse(np.any(np.isinf(result))) - - def test_synthetic_noisy_channel_detection_simulation(self): - """Simulate bad channel detection scenario.""" - # Create synthetic data where some channels are "bad" - n_good_channels = 12 - n_bad_channels = 4 - total_channels = n_good_channels + n_bad_channels - - # Create locations - locs = np.random.randn(total_channels, 3) - locs = locs / np.linalg.norm(locs, axis=1, keepdims=True) - - # Mock interpolation that simulates good reconstruction for good channels - def mock_interp_func(src_locs, dest_locs): - # Create a reconstruction matrix that works well for "good" channels - n_src = src_locs.shape[1] # Transposed input - n_dest = dest_locs.shape[1] - - # Simulate good reconstruction (identity-like for good channels) - result = np.random.randn(n_dest, n_src) * 0.1 - # Add some structure to simulate realistic interpolation - if n_src == n_dest: - result += np.eye(n_dest, n_src) - - return result, None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ): - # Calculate projector matrix - projector = calc_projector(locs, num_samples=15, subset_size=10, stream=np.random.RandomState(888)) - - # Verify basic properties - self.assertEqual(projector.shape, (total_channels, total_channels * 15)) - - # In a real RANSAC application, this projector would be used to: - # 1. Reconstruct each channel from subsets - # 2. Compute reconstruction errors - # 3. Identify channels with consistently high errors as "bad" - - # For testing purposes, just verify the projector is reasonable - self.assertFalse(np.any(np.isnan(projector))) - self.assertFalse(np.any(np.isinf(projector))) - - # Check that projector has some non-zero structure - self.assertTrue(np.any(projector != 0)) - - def test_deterministic_behavior_for_reproducibility(self): - """Test that RANSAC behavior is reproducible for debugging.""" - - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - # Use fixed seed for reproducible results - rng = np.random.RandomState(123) - return rng.randn(n_dest, n_src), None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ): - # Multiple runs with same seed should give identical results - results = [] - for _ in range(3): - result = calc_projector( - self.locs, - num_samples=5, - subset_size=6, - stream=np.random.RandomState(999), # Same seed each time - ) - results.append(result.copy()) - - # All results should be identical - for i in range(1, len(results)): - np.testing.assert_array_equal(results[0], results[i]) - - def test_parameter_validation_implicit(self): - """Test that functions handle edge cases gracefully.""" - - def mock_interp_func(src_locs, dest_locs): - n_src = src_locs.shape[1] # subset_size - n_dest = dest_locs.shape[1] # n_channels - result = np.zeros((n_dest, n_src)) - min_dim = min(n_dest, n_src) - result[:min_dim, :min_dim] = np.eye(min_dim) - return result, None - - with patch( - 'eegprep.plugins.clean_rawdata.private.ransac.sphericalSplineInterpolate', side_effect=mock_interp_func - ): - # Test minimum viable parameters - result = calc_projector(self.locs, num_samples=1, subset_size=1, stream=np.random.RandomState(111)) - - self.assertEqual(result.shape, (self.n_channels, self.n_channels)) - - # Test with subset_size equal to number of channels - result = calc_projector( - self.locs, num_samples=2, subset_size=self.n_channels, stream=np.random.RandomState(222) - ) - - self.assertEqual(result.shape, (self.n_channels, self.n_channels * 2)) - - -if __name__ == '__main__': - unittest.main() diff --git a/tools/visual_parity/visual_capture.py b/tools/visual_parity/visual_capture.py index 3635f334..6c82ec61 100644 --- a/tools/visual_parity/visual_capture.py +++ b/tools/visual_parity/visual_capture.py @@ -72,11 +72,6 @@ from eegprep.functions.studyfunc.pop_study import pop_study_dialog_spec from eegprep.functions.studyfunc.pop_studydesign import pop_studydesign_dialog_spec from eegprep.functions.studyfunc.std_checkset import std_checkset -from eegprep.plugins.ICLabel.pop_icflag import pop_icflag_dialog_spec -from eegprep.plugins.ICLabel.pop_iclabel import pop_iclabel_dialog_spec -from eegprep.plugins.ICLabel.pop_prop_extended import pop_prop_extended -from eegprep.plugins.ICLabel.pop_viewprops import pop_viewprops_dialog_spec -from eegprep.plugins.clean_rawdata.pop_clean_rawdata import pop_clean_rawdata_dialog_spec from eegprep.plugins.dipfit.pop_dipfit_gridsearch import pop_dipfit_gridsearch_dialog_spec from eegprep.plugins.dipfit.pop_dipfit_headmodel import pop_dipfit_headmodel_dialog_spec from eegprep.plugins.dipfit.pop_dipfit_loreta import pop_dipfit_loreta_dialog_spec @@ -1103,36 +1098,7 @@ def capture_pop_runica_multiple_dialog(output: pathlib.Path) -> None: _grab_dialog(dialog, output, app) -def capture_pop_iclabel_dialog(output: pathlib.Path) -> None: - """Render and capture the pop_iclabel dialog.""" - spec = pop_iclabel_dialog_spec() renderer = QtDialogRenderer() - app, dialog, _widgets = renderer.build_dialog(spec) - _grab_dialog(dialog, output, app) - - -def capture_pop_icflag_dialog(output: pathlib.Path) -> None: - """Render and capture the pop_icflag dialog.""" - spec = pop_icflag_dialog_spec() - renderer = QtDialogRenderer() - app, dialog, _widgets = renderer.build_dialog(spec) - _grab_dialog(dialog, output, app) - - -def capture_pop_prop_extended_dashboard(output: pathlib.Path) -> None: - """Render and capture the ICLabel extended property dashboard.""" - figure = pop_prop_extended( - _demo_iclabel_dashboard_eeg(), - 0, - [1, 2], - spec_opt="'freqrange', [2 40]", - scroll_event=1, - ) - figure.savefig(output, dpi=200) - plt.close(figure) - - -def capture_pop_subcomp_dialog(output: pathlib.Path) -> None: """Render and capture the pop_subcomp dialog.""" eeg = _demo_main_eeg() eeg["reject"] = {"gcompreject": np.zeros(4, dtype=int)} @@ -1166,7 +1132,6 @@ def _rejection_spec(case_id: str): return pop_autorej_dialog_spec(eeg) if case_id == "pop_selectcomps_dialog": return pop_selectcomps_dialog_spec(eeg) - if case_id == "pop_viewprops_dialog": eeg["etc"] = { "ic_classification": { "ICLabel": { @@ -1175,7 +1140,6 @@ def _rejection_spec(case_id: str): } } } - return pop_viewprops_dialog_spec(eeg, 0) if case_id == "pop_rejchan_dialog": return pop_rejchan_dialog_spec(continuous) if case_id == "pop_rejcont_dialog": @@ -1242,10 +1206,7 @@ def capture_dipfit_dialog(output: pathlib.Path, *, case_id: str) -> None: _grab_dialog(dialog, output, app) -def capture_pop_clean_rawdata_dialog(output: pathlib.Path) -> None: - """Render and capture the pop_clean_rawdata dialog.""" eeg = _demo_main_eeg() - spec = pop_clean_rawdata_dialog_spec(eeg) renderer = QtDialogRenderer() app, dialog, _widgets = renderer.build_dialog(spec) _grab_dialog(dialog, output, app) @@ -1428,11 +1389,7 @@ def _capture_case_handlers() -> dict[str, CaptureHandler]: "pop_eventstat_dialog": capture_pop_eventstat_dialog, "pop_runica_dialog": capture_pop_runica_dialog, "pop_runica_multiple_dialog": capture_pop_runica_multiple_dialog, - "pop_iclabel_dialog": capture_pop_iclabel_dialog, - "pop_icflag_dialog": capture_pop_icflag_dialog, - "iclabel_pop_prop_extended_dashboard": capture_pop_prop_extended_dashboard, "pop_subcomp_dialog": capture_pop_subcomp_dialog, - "pop_clean_rawdata_dialog": capture_pop_clean_rawdata_dialog, "pop_chansel_dialog": capture_pop_chansel_dialog, "select_multiple_datasets_dialog": capture_select_multiple_datasets_dialog, "pop_interp_dataset_index_dialog": capture_dataset_index_dialog, @@ -1473,7 +1430,6 @@ def _capture_case_handlers() -> dict[str, CaptureHandler]: "pop_rejspec_dialog", "pop_rejtrend_dialog", "pop_selectcomps_dialog", - "pop_viewprops_dialog", ) } ) diff --git a/uv.lock b/uv.lock index c4068d0d..909d2f57 100644 --- a/uv.lock +++ b/uv.lock @@ -458,33 +458,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/6e/e90527eef33f309beb811cf7c982c3aeffcce8e3edb178baa4ca3ae4a6fa/cryptography-48.0.0-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f5333311663ea94f75dd408665686aaf426563556bb5283554a3539177e03b8c", size = 4690433, upload-time = "2026-05-04T22:57:40.373Z" }, { url = "https://files.pythonhosted.org/packages/90/04/673510ed51ddff56575f306cf1617d80411ee76831ccd3097599140efdfe/cryptography-48.0.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7995ef305d7165c3f11ae07f2517e5a4f1d5c18da1376a0a9ed496336b69e5f3", size = 4710620, upload-time = "2026-05-04T22:57:42.935Z" }, { url = "https://files.pythonhosted.org/packages/14/d5/e9c4ef932c8d800490c34d8bd589d64a31d5890e27ec9e9ad532be893294/cryptography-48.0.0-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:40ba1f85eaa6959837b1d51c9767e230e14612eea4ef110ee8854ada22da1bf5", size = 4696283, upload-time = "2026-05-04T22:57:45.294Z" }, - { url = "https://files.pythonhosted.org/packages/0c/29/174b9dfb60b12d59ecfc6cfa04bc88c21b42a54f01b8aae09bb6e51e4c7f/cryptography-48.0.0-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:369a6348999f94bbd53435c894377b20ab95f25a9065c283570e70150d8abc3c", size = 5296573, upload-time = "2026-05-04T22:57:47.933Z" }, { url = "https://files.pythonhosted.org/packages/95/38/0d29a6fd7d0d1373f0c0c88a04ba20e359b257753ac497564cd660fc1d55/cryptography-48.0.0-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a0e692c683f4df67815a2d258b324e66f4738bd7a96a218c826dce4f4bd05d8f", size = 4743677, upload-time = "2026-05-04T22:57:50.067Z" }, { url = "https://files.pythonhosted.org/packages/30/be/eef653013d5c63b6a490529e0316f9ac14a37602965d4903efed1399f32b/cryptography-48.0.0-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:18349bbc56f4743c8b12dc32e2bccb2cf83ee8b69a3bba74ef8ae857e26b3d25", size = 4330808, upload-time = "2026-05-04T22:57:52.301Z" }, { url = "https://files.pythonhosted.org/packages/84/9e/500463e87abb7a0a0f9f256ec21123ecde0a7b5541a15e840ea54551fd81/cryptography-48.0.0-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:7e8eac43dfca5c4cccc6dad9a80504436fca53bb9bc3100a2386d730fbe6b602", size = 4695941, upload-time = "2026-05-04T22:57:54.603Z" }, - { url = "https://files.pythonhosted.org/packages/e3/dc/7303087450c2ec9e7fbb750e17c2abfbc658f23cbd0e54009509b7cc4091/cryptography-48.0.0-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9ccdac7d40688ecb5a3b4a604b8a88c8002e3442d6c60aead1db2a89a041560c", size = 5252579, upload-time = "2026-05-04T22:57:57.207Z" }, { url = "https://files.pythonhosted.org/packages/d0/c0/7101d3b7215edcdc90c45da544961fd8ed2d6448f77577460fa75a8443f7/cryptography-48.0.0-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:bd72e68b06bb1e96913f97dd4901119bc17f39d4586a5adf2d3e47bc2b9d58b5", size = 4743326, upload-time = "2026-05-04T22:57:59.535Z" }, { url = "https://files.pythonhosted.org/packages/ac/d8/5b833bad13016f562ab9d063d68199a4bd121d18458e439515601d3357ec/cryptography-48.0.0-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:59baa2cb386c4f0b9905bd6eb4c2a79a69a128408fd31d32ca4d7102d4156321", size = 4826672, upload-time = "2026-05-04T22:58:01.996Z" }, { url = "https://files.pythonhosted.org/packages/98/e1/7074eb8bf3c135558c73fc2bcf0f5633f912e6fb87e868a55c454080ef09/cryptography-48.0.0-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:9249e3cd978541d665967ac2cb2787fd6a62bddf1e75b3e347a594d7dacf4f74", size = 4972574, upload-time = "2026-05-04T22:58:03.968Z" }, { url = "https://files.pythonhosted.org/packages/89/6e/18e07a618bb5442ba10cf4df16e99c071365528aa570dfcb8c02e25a303b/cryptography-48.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c7378637d7d88016fa6791c159f698b3d3eed28ebf844ac36b9dc04a14dae18", size = 4684776, upload-time = "2026-05-04T22:58:13.712Z" }, { url = "https://files.pythonhosted.org/packages/be/6a/4ea3b4c6c6759794d5ee2103c304a5076dc4b19ae1f9fe47dba439e159e9/cryptography-48.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc90c0b39b2e3c65ef52c804b72e3c58f8a04ab2a1871272798e5f9572c17d20", size = 4698121, upload-time = "2026-05-04T22:58:16.448Z" }, { url = "https://files.pythonhosted.org/packages/2f/59/6ff6ad6cae03bb887da2a5860b2c9805f8dac969ef01ce563336c49bd1d1/cryptography-48.0.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:76341972e1eff8b4bea859f09c0d3e64b96ce931b084f9b9b7db8ef364c30eff", size = 4690042, upload-time = "2026-05-04T22:58:18.544Z" }, - { url = "https://files.pythonhosted.org/packages/ca/b4/fc334ed8cfd705aca282fe4d8f5ae64a8e0f74932e9feecb344610cf6e4d/cryptography-48.0.0-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:55b7718303bf06a5753dcdccf2f3945cf18ad7bffde41b61226e4db31ab89a9c", size = 5282526, upload-time = "2026-05-04T22:58:20.75Z" }, { url = "https://files.pythonhosted.org/packages/11/08/9f8c5386cc4cd90d8255c7cdd0f5baf459a08502a09de30dc51f553d38dc/cryptography-48.0.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:a64697c641c7b1b2178e573cbc31c7c6684cd56883a478d75143dbb7118036db", size = 4733116, upload-time = "2026-05-04T22:58:23.627Z" }, { url = "https://files.pythonhosted.org/packages/b8/77/99307d7574045699f8805aa500fa0fb83422d115b5400a064ddd306d7750/cryptography-48.0.0-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:561215ea3879cb1cbbf272867e2efda62476f240fb58c64de6b393ae19246741", size = 4316030, upload-time = "2026-05-04T22:58:25.581Z" }, { url = "https://files.pythonhosted.org/packages/fd/36/a608b98337af3cb2aff4818e406649d30572b7031918b04c87d979495348/cryptography-48.0.0-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:ad64688338ed4bc1a6618076ba75fd7194a5f1797ac60b47afe926285adb3166", size = 4689640, upload-time = "2026-05-04T22:58:27.747Z" }, - { url = "https://files.pythonhosted.org/packages/dd/a6/825010a291b4438aecc1f568bc428189fc1175515223632477c07dc0a6df/cryptography-48.0.0-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:906cbf0670286c6e0044156bc7d4af9cbb0ef6db9f73e52c3ec56ba6bdde5336", size = 5237657, upload-time = "2026-05-04T22:58:29.848Z" }, { url = "https://files.pythonhosted.org/packages/b9/09/4e76a09b4caa29aad535ddc806f5d4c5d01885bd978bd984fbc6ca032cae/cryptography-48.0.0-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:ea8990436d914540a40ab24b6a77c0969695ed52f4a4874c5137ccf7045a7057", size = 4732362, upload-time = "2026-05-04T22:58:32.009Z" }, { url = "https://files.pythonhosted.org/packages/18/78/444fa04a77d0cb95f417dda20d450e13c56ba8e5220fc892a1658f44f882/cryptography-48.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c18684a7f0cc9a3cb60328f496b8e3372def7c5d2df39ac267878b05565aaaae", size = 4819580, upload-time = "2026-05-04T22:58:34.254Z" }, { url = "https://files.pythonhosted.org/packages/38/85/ea67067c70a1fd4be2c63d35eeed82658023021affccc7b17705f8527dd2/cryptography-48.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9be5aafa5736574f8f15f262adc81b2a9869e2cfe9014d52a44633905b40d52c", size = 4963283, upload-time = "2026-05-04T22:58:36.376Z" }, { url = "https://files.pythonhosted.org/packages/d5/ac/f5b5995b87770c693e2596559ffafe195b4033a57f14a82268a2842953f3/cryptography-48.0.0-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:614d0949f4790582d2cc25553abd09dd723025f0c0e7c67376a1d77196743d6e", size = 4683266, upload-time = "2026-05-04T22:58:46.064Z" }, { url = "https://files.pythonhosted.org/packages/ec/c6/8b14f67e18338fbc4adb76f66c001f5c3610b3e2d1837f268f47a347dbbb/cryptography-48.0.0-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7ce4bfae76319a532a2dc68f82cc32f5676ee792a983187dac07183690e5c66f", size = 4696228, upload-time = "2026-05-04T22:58:48.22Z" }, { url = "https://files.pythonhosted.org/packages/ea/73/f808fbae9514bd91b47875b003f13e284c8c6bdfd904b7944e803937eec1/cryptography-48.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:2eb992bbd4661238c5a397594c83f5b4dc2bc5b848c365c8f991b6780efcc5c7", size = 4689097, upload-time = "2026-05-04T22:58:50.9Z" }, - { url = "https://files.pythonhosted.org/packages/93/01/d86632d7d28db8ae83221995752eeb6639ffb374c2d22955648cf8d52797/cryptography-48.0.0-cp39-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:22a5cb272895dce158b2cacdfdc3debd299019659f42947dbdac6f32d68fe832", size = 5283582, upload-time = "2026-05-04T22:58:53.017Z" }, { url = "https://files.pythonhosted.org/packages/02/e1/50edc7a50334807cc4791fc4a0ce7468b4a1416d9138eab358bfc9a3d70b/cryptography-48.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:2b4d59804e8408e2fea7d1fbaf218e5ec984325221db76e6a241a9abd6cdd95c", size = 4730479, upload-time = "2026-05-04T22:58:55.611Z" }, { url = "https://files.pythonhosted.org/packages/6f/af/99a582b1b1641ff5911ac559beb45097cf79efd4ead4657f578ef1af2d47/cryptography-48.0.0-cp39-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:984a20b0f62a26f48a3396c72e4bc34c66e356d356bf370053066b3b6d54634a", size = 4326481, upload-time = "2026-05-04T22:58:57.607Z" }, { url = "https://files.pythonhosted.org/packages/90/ee/89aa26a06ef0a7d7611788ffd571a7c50e368cc6a4d5eef8b4884e866edb/cryptography-48.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5a5ed8fde7a1d09376ca0b40e68cd59c69fe23b1f9768bd5824f54681626032a", size = 4688713, upload-time = "2026-05-04T22:59:00.077Z" }, - { url = "https://files.pythonhosted.org/packages/70/ba/bcb1b0bb7a33d4c7c0c4d4c7874b4a62ae4f56113a5f4baefa362dfb1f0f/cryptography-48.0.0-cp39-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:8cd666227ef7af430aa5914a9910e0ddd703e75f039cef0825cd0da71b6b711a", size = 5238165, upload-time = "2026-05-04T22:59:02.317Z" }, { url = "https://files.pythonhosted.org/packages/c9/70/ca4003b1ce5ca3dc3186ada51908c8a9b9ff7d5cab83cc0d43ee14ec144f/cryptography-48.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9071196d81abc88b3516ac8cdfad32e2b66dd4a5393a8e68a961e9161ddc6239", size = 4729947, upload-time = "2026-05-04T22:59:05.255Z" }, { url = "https://files.pythonhosted.org/packages/44/a0/4ec7cf774207905aef1a8d11c3750d5a1db805eb380ee4e16df317870128/cryptography-48.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:1e2d54c8be6152856a36f0882ab231e70f8ec7f14e93cf87db8a2ed056bf160c", size = 4822059, upload-time = "2026-05-04T22:59:07.802Z" }, { url = "https://files.pythonhosted.org/packages/1e/75/a2e55f99c16fcac7b5d6c1eb19ad8e00799854d6be5ca845f9259eae1681/cryptography-48.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a5da777e32ffed6f85a7b2b3f7c5cbc88c146bfcd0a1d7baf5fcc6c52ee35dd4", size = 4960575, upload-time = "2026-05-04T22:59:09.851Z" }, @@ -494,76 +488,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b4/e1/48cedb2fe63626e91ded1edad159e2a4fb8b6906c4425eb7749673077ce7/cryptography-48.0.0-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:4defde8685ae324a9eb9d818717e93b4638ef67070ac9bc15b8ca85f63048355", size = 4666800, upload-time = "2026-05-04T22:59:27.474Z" }, ] -[[package]] -name = "cuda-bindings" -version = "13.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cuda-pathfinder", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/1a/fe/7351d7e586a8b4c9f89731bfe4cf0148223e8f9903ff09571f78b3fb0682/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b395f79cb89ce0cd8effff07c4a1e20101b873c256a1aeb286e8fd7bd0f556", size = 5744254, upload-time = "2026-03-11T00:12:29.798Z" }, - { url = "https://files.pythonhosted.org/packages/aa/ef/184aa775e970fc089942cd9ec6302e6e44679d4c14549c6a7ea45bf7f798/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6f3682ec3c4769326aafc67c2ba669d97d688d0b7e63e659d36d2f8b72f32d6", size = 6329075, upload-time = "2026-03-11T00:12:32.319Z" }, - { url = "https://files.pythonhosted.org/packages/e0/a9/3a8241c6e19483ac1f1dcf5c10238205dcb8a6e9d0d4d4709240dff28ff4/cuda_bindings-13.2.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:721104c603f059780d287969be3d194a18d0cc3b713ed9049065a1107706759d", size = 5730273, upload-time = "2026-03-11T00:12:37.18Z" }, - { url = "https://files.pythonhosted.org/packages/e9/94/2748597f47bb1600cd466b20cab4159f1530a3a33fe7f70fee199b3abb9e/cuda_bindings-13.2.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1eba9504ac70667dd48313395fe05157518fd6371b532790e96fbb31bbb5a5e1", size = 6313924, upload-time = "2026-03-11T00:12:39.462Z" }, - { url = "https://files.pythonhosted.org/packages/52/c8/b2589d68acf7e3d63e2be330b84bc25712e97ed799affbca7edd7eae25d6/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e865447abfb83d6a98ad5130ed3c70b1fc295ae3eeee39fd07b4ddb0671b6788", size = 5722404, upload-time = "2026-03-11T00:12:44.041Z" }, - { url = "https://files.pythonhosted.org/packages/1f/92/f899f7bbb5617bb65ec52a6eac1e9a1447a86b916c4194f8a5001b8cde0c/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46d8776a55d6d5da9dd6e9858fba2efcda2abe6743871dee47dd06eb8cb6d955", size = 6320619, upload-time = "2026-03-11T00:12:45.939Z" }, - { url = "https://files.pythonhosted.org/packages/df/93/eef988860a3ca985f82c4f3174fc0cdd94e07331ba9a92e8e064c260337f/cuda_bindings-13.2.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6629ca2df6f795b784752409bcaedbd22a7a651b74b56a165ebc0c9dcbd504d0", size = 5614610, upload-time = "2026-03-11T00:12:50.337Z" }, - { url = "https://files.pythonhosted.org/packages/18/23/6db3aba46864aee357ab2415135b3fe3da7e9f1fa0221fa2a86a5968099c/cuda_bindings-13.2.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7dca0da053d3b4cc4869eff49c61c03f3c5dbaa0bcd712317a358d5b8f3f385d", size = 6149914, upload-time = "2026-03-11T00:12:52.374Z" }, - { url = "https://files.pythonhosted.org/packages/c0/87/87a014f045b77c6de5c8527b0757fe644417b184e5367db977236a141602/cuda_bindings-13.2.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a6464b30f46692d6c7f65d4a0e0450d81dd29de3afc1bb515653973d01c2cd6e", size = 5685673, upload-time = "2026-03-11T00:12:56.371Z" }, - { url = "https://files.pythonhosted.org/packages/ee/5e/c0fe77a73aaefd3fff25ffaccaac69c5a63eafdf8b9a4c476626ef0ac703/cuda_bindings-13.2.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f4af9f3e1be603fa12d5ad6cfca7844c9d230befa9792b5abdf7dd79979c3626", size = 6191386, upload-time = "2026-03-11T00:12:58.965Z" }, - { url = "https://files.pythonhosted.org/packages/5f/58/ed2c3b39c8dd5f96aa7a4abef0d47a73932c7a988e30f5fa428f00ed0da1/cuda_bindings-13.2.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df850a1ff8ce1b3385257b08e47b70e959932f5f432d0a4e46a355962b4e4771", size = 5507469, upload-time = "2026-03-11T00:13:04.063Z" }, - { url = "https://files.pythonhosted.org/packages/1f/01/0c941b112ceeb21439b05895eace78ca1aa2eaaf695c8521a068fd9b4c00/cuda_bindings-13.2.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8a16384c6494e5485f39314b0b4afb04bee48d49edb16d5d8593fd35bbd231b", size = 6059693, upload-time = "2026-03-11T00:13:06.003Z" }, -] - -[[package]] -name = "cuda-pathfinder" -version = "1.5.4" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/11/d0/c177e29701cf1d3008d7d2b16b5fc626592ce13bd535f8795c5f57187e0e/cuda_pathfinder-1.5.4-py3-none-any.whl", hash = "sha256:9563d3175ce1828531acf4b94e1c1c7d67208c347ca002493e2654878b26f4b7", size = 51657, upload-time = "2026-04-27T22:42:07.712Z" }, -] - -[[package]] -name = "cuda-toolkit" -version = "13.0.2" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/57/b2/453099f5f3b698d7d0eab38916aac44c7f76229f451709e2eb9db6615dcd/cuda_toolkit-13.0.2-py2.py3-none-any.whl", hash = "sha256:b198824cf2f54003f50d64ada3a0f184b42ca0846c1c94192fa269ecd97a66eb", size = 2364, upload-time = "2025-12-19T23:24:07.328Z" }, -] - -[package.optional-dependencies] -cudart = [ - { name = "nvidia-cuda-runtime", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" }, -] -cufft = [ - { name = "nvidia-cufft", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" }, -] -cufile = [ - { name = "nvidia-cufile", marker = "sys_platform == 'linux'" }, -] -cupti = [ - { name = "nvidia-cuda-cupti", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" }, -] -curand = [ - { name = "nvidia-curand", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" }, -] -cusolver = [ - { name = "nvidia-cusolver", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" }, -] -cusparse = [ - { name = "nvidia-cusparse", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" }, -] -nvjitlink = [ - { name = "nvidia-nvjitlink", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" }, -] -nvrtc = [ - { name = "nvidia-cuda-nvrtc", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" }, -] -nvtx = [ - { name = "nvidia-nvtx", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" }, -] - [[package]] name = "cycler" version = "0.12.1" @@ -681,7 +605,6 @@ all = [ { name = "sphinx-gallery" }, { name = "sphinx-togglebutton" }, { name = "sphinxcontrib-spelling" }, - { name = "torch" }, ] console = [ { name = "ipython", version = "8.39.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -708,9 +631,6 @@ gui = [ { name = "pyqtgraph" }, { name = "pyside6" }, ] -torch = [ - { name = "torch" }, -] [package.dev-dependencies] dev = [ @@ -736,7 +656,6 @@ requires-dist = [ { name = "eegprep", extras = ["docs"], marker = "extra == 'all'" }, { name = "eegprep", extras = ["gui"], marker = "extra == 'all'" }, { name = "eegprep", extras = ["gui"], marker = "extra == 'console'" }, - { name = "eegprep", extras = ["torch"], marker = "extra == 'all'" }, { name = "h5py", specifier = ">=3.3.0" }, { name = "ipython", marker = "extra == 'console'", specifier = ">=8.0" }, { name = "matplotlib", specifier = ">=3.4.0" }, @@ -764,9 +683,8 @@ requires-dist = [ { name = "sphinxcontrib-spelling", marker = "extra == 'docs'", specifier = ">=7.1.0" }, { name = "sympy", specifier = ">=1.14.0" }, { name = "threadpoolctl", specifier = ">=3.6.0" }, - { name = "torch", marker = "extra == 'torch'", specifier = ">=2.0" }, ] -provides-extras = ["torch", "gui", "console", "docs", "all"] +provides-extras = ["gui", "console", "docs", "all"] [package.metadata.requires-dev] dev = [ @@ -804,15 +722,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] -[[package]] -name = "filelock" -version = "3.29.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b5/fe/997687a931ab51049acce6fa1f23e8f01216374ea81374ddee763c493db5/filelock-3.29.0.tar.gz", hash = "sha256:69974355e960702e789734cb4871f884ea6fe50bd8404051a3530bc07809cf90", size = 57571, upload-time = "2026-04-19T15:39:10.068Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/81/47/dd9a212ef6e343a6857485ffe25bba537304f1913bdbed446a23f7f592e1/filelock-3.29.0-py3-none-any.whl", hash = "sha256:96f5f6344709aa1572bbf631c640e4ebeeb519e08da902c39a001882f30ac258", size = 39812, upload-time = "2026-04-19T15:39:08.752Z" }, -] - [[package]] name = "fonttools" version = "4.62.1" @@ -935,7 +844,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/d6/b3db928fc329b1b19ba32ffe143d2305f3aaafc583f5e1074c74ec445189/greenlet-3.2.5-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:34cc7cf8ab6f4b85298b01e13e881265ee7b3c1daf6bc10a2944abc15d4f87c3", size = 275803, upload-time = "2026-02-20T20:06:42.541Z" }, { url = "https://files.pythonhosted.org/packages/b3/ff/ab0ad4ff3d9e1faa266de4f6c79763b33fccd9265995f2940192494cc0ec/greenlet-3.2.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c11fe0cfb0ce33132f0b5d27eeadd1954976a82e5e9b60909ec2c4b884a55382", size = 633556, upload-time = "2026-02-20T20:30:41.594Z" }, { url = "https://files.pythonhosted.org/packages/da/dd/7b3ac77099a1671af8077ecedb12c9a1be1310e4c35bb69fd34c18ab6093/greenlet-3.2.5-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:a145f4b1c4ed7a2c94561b7f18b4beec3d3fb6f0580db22f7ed1d544e0620b34", size = 644943, upload-time = "2026-02-20T20:37:23.084Z" }, - { url = "https://files.pythonhosted.org/packages/56/f0/bea7e7909ea9045b0c5055dad1ec9b81c82b761b4567e625f4f8349acfa1/greenlet-3.2.5-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:edbf4ab9a7057ee430a678fe2ef37ea5d69125d6bdc7feb42ed8d871c737e63b", size = 640849, upload-time = "2026-02-20T20:43:57.305Z" }, { url = "https://files.pythonhosted.org/packages/0f/36/84630e9ff1dfc8b7690957c0f77834a84eabdbd9c4977c3a2d0cbd5325c2/greenlet-3.2.5-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc1d01bdd67db3e5711e6246e451d7a0f75fae7bbf40adde129296a7f9aa7cc9", size = 639841, upload-time = "2026-02-20T20:07:17.473Z" }, { url = "https://files.pythonhosted.org/packages/12/c4/6a2ee6c676dea7a05a3c3c1291fbc8ea44f26456b0accc891471293825af/greenlet-3.2.5-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bd593db7ee1fa8a513a48a404f8cc4126998a48025e3f5cbbc68d51be0a6bf66", size = 588813, upload-time = "2026-02-20T20:07:56.171Z" }, { url = "https://files.pythonhosted.org/packages/01/c0/75e75c2c993aa850292561ec80f5c263e3924e5843aa95a38716df69304c/greenlet-3.2.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ac8db07bced2c39b987bba13a3195f8157b0cfbce54488f86919321444a1cc3c", size = 1117377, upload-time = "2026-02-20T20:32:48.452Z" }, @@ -943,7 +851,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d8/7b/c6e1192c795c0c12871e199237909a6bd35757d92c8472c7c019959b8637/greenlet-3.2.5-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:acabf468466d18017e2ae5fbf1a5a88b86b48983e550e1ae1437b69a83d9f4ac", size = 276916, upload-time = "2026-02-20T20:06:18.166Z" }, { url = "https://files.pythonhosted.org/packages/3e/b6/9887b559f3e1952d23052ec352e9977e808a2246c7cb8282a38337221e88/greenlet-3.2.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:472841de62d60f2cafd60edd4fd4dd7253eb70e6eaf14b8990dcaf177f4af957", size = 636107, upload-time = "2026-02-20T20:30:43.362Z" }, { url = "https://files.pythonhosted.org/packages/8a/be/e3e48b63bbc27d660fa1d98aecb64906b90a12e686a436169c1330ef34b2/greenlet-3.2.5-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:7d951e7d628a6e8b68af469f0fe4f100ef64c4054abeb9cdafbfaa30a920c950", size = 648240, upload-time = "2026-02-20T20:37:24.608Z" }, - { url = "https://files.pythonhosted.org/packages/17/f6/2cbe999683f759f14f598234f04ae8ba6f22953a624b3a7a630003e6bfff/greenlet-3.2.5-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:87b791dd0e031a574249af717ac36f7031b18c35329561c1e0368201c18caf1f", size = 644170, upload-time = "2026-02-20T20:43:59.002Z" }, { url = "https://files.pythonhosted.org/packages/4c/ac/e731ed62576e91e533b36d0d97325adc2786674ab9e48ed8a6a24f4ef4e9/greenlet-3.2.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8317d732e2ae0935d9ed2af2ea876fa714cf6f3b887a31ca150b54329b0a6e9", size = 643313, upload-time = "2026-02-20T20:07:19.012Z" }, { url = "https://files.pythonhosted.org/packages/70/64/99e5cdceb494bd4c1341c45b93f322601d2c8a5e1e4d1c7a2d24c5ed0570/greenlet-3.2.5-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce8aed6fdd5e07d3cbb988cbdc188266a4eb9e1a52db9ef5c6526e59962d3933", size = 591295, upload-time = "2026-02-20T20:07:57.286Z" }, { url = "https://files.pythonhosted.org/packages/ee/e9/968e11f388c2b8792d3b8b40a57984c894a3b4745dae3662dce722653bc5/greenlet-3.2.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:60c06b502d56d5451f60ca665691da29f79ed95e247bcf8ce5024d7bbe64acb9", size = 1120277, upload-time = "2026-02-20T20:32:50.103Z" }, @@ -951,7 +858,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ad/32/022b21523eee713e7550162d5ca6aed23f913cc2c6232b154b9fd9badc07/greenlet-3.2.5-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:2acb30e77042f747ca81f0a10cc153296567e92e666c5e1b117f4595afd43352", size = 278412, upload-time = "2026-02-20T20:03:15.02Z" }, { url = "https://files.pythonhosted.org/packages/90/c5/8a3b0ed3cc34d8b988a44349437dfa0941f9c23ac108175f7b4ccea97111/greenlet-3.2.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:393c03c26c865f17f31d8db2f09603fadbe0581ad85a5d5908b131549fc38217", size = 644616, upload-time = "2026-02-20T20:30:44.823Z" }, { url = "https://files.pythonhosted.org/packages/b1/2c/2627bea183554695016af6cae93d7474fa90f61e5a6601a84ae7841cb720/greenlet-3.2.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:04e6a202cde56043fd355fefd1552c4caa5c087528121871d950eb4f1b51fa99", size = 658813, upload-time = "2026-02-20T20:37:26.255Z" }, - { url = "https://files.pythonhosted.org/packages/44/c6/a80fc96f7cca7962dd972875d12c52dfabc94cb02bfeb19f3e7e169fca44/greenlet-3.2.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d5583b2ffa677578a384337ee13125bdf9a427485d689014b39d638a4f3d8dbe", size = 653512, upload-time = "2026-02-20T20:44:00.343Z" }, { url = "https://files.pythonhosted.org/packages/2f/1b/75a5aeff487a26ba427a3837da6372f1fe6f2a9c6b2898e28ac99d491c11/greenlet-3.2.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:45fcea7b697b91290b36eafc12fff479aca6ba6500d98ef6f34d5634c7119cbe", size = 655426, upload-time = "2026-02-20T20:07:20.124Z" }, { url = "https://files.pythonhosted.org/packages/53/91/9b5dfb4f3c88f8247c7a8f4c3759f0740bfa6bb0c59a9f6bf938e913df56/greenlet-3.2.5-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f96e2bb8a56b7e1aed1dbfbbe0050cb2ecca99c7c91892fd1771e3afab63b3e3", size = 611138, upload-time = "2026-02-20T20:07:58.966Z" }, { url = "https://files.pythonhosted.org/packages/b4/8d/d0b086410512d9859c84e9242a9b341de9f5566011ddf3a3f6886b842b61/greenlet-3.2.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d7456e67b0be653dfe643bb37d9566cd30939c80f858e2ce6d2d54951f75b14a", size = 1126896, upload-time = "2026-02-20T20:32:52.198Z" }, @@ -959,7 +865,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dd/95/d5d332fb73affaf7a1fbe80e49c2c7eae4f17c645af24a3b3fa25736d6f0/greenlet-3.2.5-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:f2cc88b50b9006b324c1b9f5f3552f9d4564c78af57cdfb4c7baf4f0aa089146", size = 277166, upload-time = "2026-02-20T20:03:57.077Z" }, { url = "https://files.pythonhosted.org/packages/6c/77/89458e20db5a4f1c64f9a0191561227e76d809941ca2d7529006d17d3450/greenlet-3.2.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e66872daffa360b2537170b73ad530f14fa31785b1bc78080125d92edf0a6def", size = 644674, upload-time = "2026-02-20T20:30:46.118Z" }, { url = "https://files.pythonhosted.org/packages/90/f8/9962175d2f2eaa629a7fd7545abacc8c4deda3baa4e52c1526d2eb5f5546/greenlet-3.2.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c5445ddb7b586d870dad32ca9fc47c287d6022a528d194efdb8912093c5303ad", size = 658834, upload-time = "2026-02-20T20:37:27.466Z" }, - { url = "https://files.pythonhosted.org/packages/81/71/52c21a7106ce5218aa6fa59ec32825b2655f875a09b69f68bd3e5d01feb3/greenlet-3.2.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:fd904626b8779810062cb455514594776e3cba3b8c0ba4939894df9f7b384971", size = 653091, upload-time = "2026-02-20T20:44:01.927Z" }, { url = "https://files.pythonhosted.org/packages/f5/d7/826d0e080f0a7ad5ec47c8d143bbd3ca0887657bb806595fe2434d12938a/greenlet-3.2.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:752c896a8c976548faafe8a306d446c6a4c68d4fd24699b84d4393bd9ac69a8e", size = 655760, upload-time = "2026-02-20T20:07:21.551Z" }, { url = "https://files.pythonhosted.org/packages/41/cc/33bd4c2f816be8c8e16f71740c4130adf3a66a3dd2ba29de72b9d8dd1096/greenlet-3.2.5-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:499b809e7738c8af0ff9ac9d5dd821cb93f4293065a9237543217f0b252f950a", size = 614132, upload-time = "2026-02-20T20:08:00.351Z" }, { url = "https://files.pythonhosted.org/packages/48/79/f3891dcfc59097474a53cc3c624f2f2465e431ab493bda043b8c873fb20a/greenlet-3.2.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:2c7429f6e9cea7cbf2637d86d3db12806ba970f7f972fcab39d6b54b4457cbaf", size = 1125286, upload-time = "2026-02-20T20:32:54.032Z" }, @@ -967,7 +872,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f6/9d/4e9b941be05f8da7ba804c6413761d2c11cca05994cbf0a015bd729419f0/greenlet-3.2.5-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:7123b29e6bad2f3f89681be4ef316480fca798ebe8d22fbaced9cc3775007a4f", size = 277627, upload-time = "2026-02-20T20:06:04.798Z" }, { url = "https://files.pythonhosted.org/packages/23/cb/a73625c9a35138330014ecf3740c0d62e0c2b5e7279bb7f2586b1b199fac/greenlet-3.2.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6e8fe0c72603201a86b2e038daf9b6c8570715f8779566419cff543b6ace88de", size = 690001, upload-time = "2026-02-20T20:30:47.754Z" }, { url = "https://files.pythonhosted.org/packages/83/49/6d1531109507bce7dfb23acf57a87013627ed3ac058851176e443a6a9134/greenlet-3.2.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:050703a60603db0e817364d69e048c70af299040c13a7e67792b9e62d4571196", size = 702953, upload-time = "2026-02-20T20:37:29.125Z" }, - { url = "https://files.pythonhosted.org/packages/90/ac/6d8fff3b273fc60ad4b46f8411fe91c1e4cca064dfff68d096bc982fa6d0/greenlet-3.2.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:04633da773ae432649a3f092a8e4add390732cc9e1ab52c8ff2c91b8dc86f202", size = 698353, upload-time = "2026-02-20T20:44:03.547Z" }, { url = "https://files.pythonhosted.org/packages/f7/38/f958ee90fab93529b30cc1e4a59b27c1112b640570043a84af84da3b3b98/greenlet-3.2.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6712bfd520530eb67331813f7112d3ee18e206f48b3d026d8a96cd2d2ad20251", size = 698995, upload-time = "2026-02-20T20:07:22.663Z" }, { url = "https://files.pythonhosted.org/packages/51/c1/a603906e79716d61f08afedaf8aed62017661457aef233d62d6e57ecd511/greenlet-3.2.5-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bc06a78fa3ffbe2a75f1ebc7e040eacf6fa1050a9432953ab111fbbbf0d03c1", size = 661175, upload-time = "2026-02-20T20:08:01.477Z" }, ] @@ -1836,41 +1740,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, ] -[[package]] -name = "networkx" -version = "3.4.2" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.11'", -] -sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263, upload-time = "2024-10-21T12:39:36.247Z" }, -] - -[[package]] -name = "networkx" -version = "3.6.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version == '3.13.*' and sys_platform == 'win32'", - "python_full_version == '3.13.*' and sys_platform == 'emscripten'", - "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version == '3.12.*' and sys_platform == 'win32'", - "python_full_version == '3.12.*' and sys_platform == 'emscripten'", - "python_full_version == '3.12.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version == '3.11.*' and sys_platform == 'win32'", - "python_full_version == '3.11.*' and sys_platform == 'emscripten'", - "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", -] -sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, -] - [[package]] name = "nh3" version = "0.3.5" @@ -2009,155 +1878,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl", hash = "sha256:3149da9874af890bcc2a82ef7aae5484e5aa81cb2778f08e3c307ba6d963721b", size = 69255, upload-time = "2025-12-02T16:39:11.561Z" }, ] -[[package]] -name = "nvidia-cublas" -version = "13.1.0.3" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/a5/fce49e2ae977e0ccc084e5adafceb4f0ac0c8333cb6863501618a7277f67/nvidia_cublas-13.1.0.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2", size = 542851226, upload-time = "2025-10-09T08:59:04.818Z" }, - { url = "https://files.pythonhosted.org/packages/e7/44/423ac00af4dd95a5aeb27207e2c0d9b7118702149bf4704c3ddb55bb7429/nvidia_cublas-13.1.0.3-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171", size = 423133236, upload-time = "2025-10-09T08:59:32.536Z" }, -] - -[[package]] -name = "nvidia-cuda-cupti" -version = "13.0.85" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/2a/80353b103fc20ce05ef51e928daed4b6015db4aaa9162ed0997090fe2250/nvidia_cuda_cupti-13.0.85-py3-none-manylinux_2_25_aarch64.whl", hash = "sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151", size = 10310827, upload-time = "2025-09-04T08:26:42.012Z" }, - { url = "https://files.pythonhosted.org/packages/33/6d/737d164b4837a9bbd202f5ae3078975f0525a55730fe871d8ed4e3b952b0/nvidia_cuda_cupti-13.0.85-py3-none-manylinux_2_25_x86_64.whl", hash = "sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8", size = 10715597, upload-time = "2025-09-04T08:26:51.312Z" }, -] - -[[package]] -name = "nvidia-cuda-nvrtc" -version = "13.0.88" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/68/483a78f5e8f31b08fb1bb671559968c0ca3a065ac7acabfc7cee55214fd6/nvidia_cuda_nvrtc-13.0.88-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575", size = 90215200, upload-time = "2025-09-04T08:28:44.204Z" }, - { url = "https://files.pythonhosted.org/packages/b7/dc/6bb80850e0b7edd6588d560758f17e0550893a1feaf436807d64d2da040f/nvidia_cuda_nvrtc-13.0.88-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b", size = 43015449, upload-time = "2025-09-04T08:28:20.239Z" }, -] - -[[package]] -name = "nvidia-cuda-runtime" -version = "13.0.96" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/4f/17d7b9b8e285199c58ce28e31b5c5bbaa4d8271af06a89b6405258245de2/nvidia_cuda_runtime-13.0.96-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55", size = 2261060, upload-time = "2025-10-09T08:55:15.78Z" }, - { url = "https://files.pythonhosted.org/packages/2e/24/d1558f3b68b1d26e706813b1d10aa1d785e4698c425af8db8edc3dced472/nvidia_cuda_runtime-13.0.96-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548", size = 2243632, upload-time = "2025-10-09T08:55:36.117Z" }, -] - -[[package]] -name = "nvidia-cudnn-cu13" -version = "9.20.0.48" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-cublas", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/56/c5/83384d846b2fd17c44bd499b36c75a45ed4f095fbbb2252294e89cea5c5c/nvidia_cudnn_cu13-9.20.0.48-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:e31454ae00094b0c55319d9d15b6fa2fc50a9e1c0f5c8c80fb75258234e731e1", size = 444574296, upload-time = "2026-03-09T19:28:27.751Z" }, - { url = "https://files.pythonhosted.org/packages/6e/5e/edb9c0ae051602c3ccaffe424256463636d639e27d7f302dde9975ef9e7a/nvidia_cudnn_cu13-9.20.0.48-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:0c45dd8eeb50b603f07995b1b300c62ffe6a1980482b82b3bcf94a4ca9d49304", size = 366173588, upload-time = "2026-03-09T19:29:34.474Z" }, -] - -[[package]] -name = "nvidia-cufft" -version = "12.0.0.61" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-nvjitlink", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" }, - { url = "https://files.pythonhosted.org/packages/a8/2f/7b57e29836ea8714f81e9898409196f47d772d5ddedddf1592eadb8ab743/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3", size = 214085489, upload-time = "2025-09-04T08:31:56.044Z" }, -] - -[[package]] -name = "nvidia-cufile" -version = "1.15.1.6" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/70/4f193de89a48b71714e74602ee14d04e4019ad36a5a9f20c425776e72cd6/nvidia_cufile-1.15.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08a3ecefae5a01c7f5117351c64f17c7c62efa5fffdbe24fc7d298da19cd0b44", size = 1223672, upload-time = "2025-09-04T08:32:22.779Z" }, - { url = "https://files.pythonhosted.org/packages/ab/73/cc4a14c9813a8a0d509417cf5f4bdaba76e924d58beb9864f5a7baceefbf/nvidia_cufile-1.15.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:bdc0deedc61f548bddf7733bdc216456c2fdb101d020e1ab4b88d232d5e2f6d1", size = 1136992, upload-time = "2025-09-04T08:32:14.119Z" }, -] - -[[package]] -name = "nvidia-curand" -version = "10.4.0.35" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/72/7c2ae24fb6b63a32e6ae5d241cc65263ea18d08802aaae087d9f013335a2/nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:133df5a7509c3e292aaa2b477afd0194f06ce4ea24d714d616ff36439cee349a", size = 61962106, upload-time = "2025-08-04T10:21:41.128Z" }, - { url = "https://files.pythonhosted.org/packages/a5/9f/be0a41ca4a4917abf5cb9ae0daff1a6060cc5de950aec0396de9f3b52bc5/nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:1aee33a5da6e1db083fe2b90082def8915f30f3248d5896bcec36a579d941bfc", size = 59544258, upload-time = "2025-08-04T10:22:03.992Z" }, -] - -[[package]] -name = "nvidia-cusolver" -version = "12.0.4.66" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-cublas", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, - { name = "nvidia-cusparse", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, - { name = "nvidia-nvjitlink", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" }, - { url = "https://files.pythonhosted.org/packages/5f/67/cba3777620cdacb99102da4042883709c41c709f4b6323c10781a9c3aa34/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112", size = 200941980, upload-time = "2025-09-04T08:33:22.767Z" }, -] - -[[package]] -name = "nvidia-cusparse" -version = "12.6.3.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-nvjitlink", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" }, - { url = "https://files.pythonhosted.org/packages/fa/18/623c77619c31d62efd55302939756966f3ecc8d724a14dab2b75f1508850/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b", size = 145942937, upload-time = "2025-09-04T08:33:58.029Z" }, -] - -[[package]] -name = "nvidia-cusparselt-cu13" -version = "0.8.1" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/46/e1/cdc1797eadf82d3a9a575a19b33fdc871a97edbec42c00b5b5e914f4aff4/nvidia_cusparselt_cu13-0.8.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4dca476c50bf4780d46cd0bfbd82e2bc10a08e4fef7950917ce8d7578d22a23f", size = 221051344, upload-time = "2025-09-05T18:49:51.289Z" }, - { url = "https://files.pythonhosted.org/packages/34/7d/2661f2fb3ac4302f3a246f5fc030213ac60c1fe0bce84f9783dbd831dbb7/nvidia_cusparselt_cu13-0.8.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:786ce87568c303fadb5afcc7102d454cd3040d75f6f8626f5db460d1871f4dd0", size = 170148586, upload-time = "2025-09-05T18:50:50.248Z" }, -] - -[[package]] -name = "nvidia-nccl-cu13" -version = "2.29.7" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/72/0d/daf50d44177ee0cbc7ff0a0c91eb5ff676c82be42f9a970bc7597f440c3a/nvidia_nccl_cu13-2.29.7-py3-none-manylinux_2_18_aarch64.whl", hash = "sha256:674a12383e3c38a1bcccae7d4f3633b37852230b6047883cb2f4c2d1b36d9bf5", size = 206014712, upload-time = "2026-03-03T05:34:20.843Z" }, - { url = "https://files.pythonhosted.org/packages/67/f4/58e4e91b6919367c7aafb8e36fce9aad1a3047e536bf7e2fd560927d3a4c/nvidia_nccl_cu13-2.29.7-py3-none-manylinux_2_18_x86_64.whl", hash = "sha256:edd81538446786ec3b73972543e53bb43bcaf0bfc8ef76cb679fcc390ffe136d", size = 205976000, upload-time = "2026-03-03T05:36:24.472Z" }, -] - -[[package]] -name = "nvidia-nvjitlink" -version = "13.0.88" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/56/7a/123e033aaff487c77107195fa5a2b8686795ca537935a24efae476c41f05/nvidia_nvjitlink-13.0.88-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b", size = 40713933, upload-time = "2025-09-04T08:35:43.553Z" }, - { url = "https://files.pythonhosted.org/packages/ab/2c/93c5250e64df4f894f1cbb397c6fd71f79813f9fd79d7cd61de3f97b3c2d/nvidia_nvjitlink-13.0.88-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c", size = 38768748, upload-time = "2025-09-04T08:35:20.008Z" }, -] - -[[package]] -name = "nvidia-nvshmem-cu13" -version = "3.4.5" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/0f/05cc9c720236dcd2db9c1ab97fff629e96821be2e63103569da0c9b72f19/nvidia_nvshmem_cu13-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9", size = 60215947, upload-time = "2025-09-06T00:32:20.022Z" }, - { url = "https://files.pythonhosted.org/packages/3c/35/a9bf80a609e74e3b000fef598933235c908fcefcef9026042b8e6dfde2a9/nvidia_nvshmem_cu13-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80", size = 60412546, upload-time = "2025-09-06T00:32:41.564Z" }, -] - -[[package]] -name = "nvidia-nvtx" -version = "13.0.85" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/f3/d86c845465a2723ad7e1e5c36dcd75ddb82898b3f53be47ebd429fb2fa5d/nvidia_nvtx-13.0.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:4936d1d6780fbe68db454f5e72a42ff64d1fd6397df9f363ae786930fd5c1cd4", size = 148047, upload-time = "2025-09-04T08:29:01.761Z" }, - { url = "https://files.pythonhosted.org/packages/a8/64/3708a90d1ebe202ffdeb7185f878a3c84d15c2b2c31858da2ce0583e2def/nvidia_nvtx-13.0.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb7780edb6b14107373c835bf8b72e7a178bac7367e23da7acb108f973f157a6", size = 148878, upload-time = "2025-09-04T08:28:53.627Z" }, -] - [[package]] name = "oct2py" version = "5.8.0" @@ -3560,59 +3280,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/61/cceae43728b7de99d9b847560c262873a1f6c98202171fd5ed62640b494b/tomli-2.4.1-py3-none-any.whl", hash = "sha256:0d85819802132122da43cb86656f8d1f8c6587d54ae7dcaf30e90533028b49fe", size = 14583, upload-time = "2026-03-25T20:22:03.012Z" }, ] -[[package]] -name = "torch" -version = "2.12.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cuda-bindings", marker = "sys_platform == 'linux'" }, - { name = "cuda-toolkit", extra = ["cudart", "cufft", "cufile", "cupti", "curand", "cusolver", "cusparse", "nvjitlink", "nvrtc", "nvtx"], marker = "sys_platform == 'linux'" }, - { name = "filelock" }, - { name = "fsspec" }, - { name = "jinja2" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "nvidia-cublas", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu13", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu13", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu13", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nvshmem-cu13", marker = "sys_platform == 'linux'" }, - { name = "setuptools" }, - { name = "sympy" }, - { name = "triton", marker = "sys_platform == 'linux'" }, - { name = "typing-extensions" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/b7/53fe0436586716ab7aecff41e26b9302d57c85ded481fd83a2cd741e6b4e/torch-2.12.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:1834bd984f8a2f4f16bdfbeecca9146184b220aa46276bf5756735b5dae12812", size = 87981887, upload-time = "2026-05-13T14:55:53.234Z" }, - { url = "https://files.pythonhosted.org/packages/34/60/d930eac44c30de06ed16f6d1ba4e785e1632532b50d8f0bf9bf699a4d0c7/torch-2.12.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:d4d029801cb7b6df858804a2a21b00cc2aa0bf0ee5d2ab18d343c9e9e5681f35", size = 426355000, upload-time = "2026-05-13T14:54:31.944Z" }, - { url = "https://files.pythonhosted.org/packages/8e/0c/c76b6a087820bab55705b94dfc074e520de9ae91f5ef90da2ecbf2a3ef12/torch-2.12.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:d47e7dee68ac4cd7a068b26bcd6b989935427709fae1c8f7bd0019978f829e15", size = 532144998, upload-time = "2026-05-13T14:56:05.523Z" }, - { url = "https://files.pythonhosted.org/packages/4a/64/8a0d036e166a6aa85ee09bef072f3655d1ba5d5486a68d1b03b6813c01b3/torch-2.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:cf9839790285dd472e7a16aafcb4a4e6bf58ec1b494045044b0eefb0eb4bd1f2", size = 122949877, upload-time = "2026-05-13T14:55:46.841Z" }, - { url = "https://files.pythonhosted.org/packages/18/62/131124fb95df03811b8260d1d43dcc5ee85ea1a344b964613d7efe77fb08/torch-2.12.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:10802fd383bbfed646212e765a72c37d2185205d4f26eb197a254e8ac7ddcb25", size = 87990344, upload-time = "2026-05-13T14:55:42.154Z" }, - { url = "https://files.pythonhosted.org/packages/12/9c/dda0dbd547dc549839824135f223792fd0e725f28ed0715dda366b7acaa2/torch-2.12.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:c12592630aef72feaf18bd3f197ef587bbfa21131b31c38b23ab2e55fce92e36", size = 426362932, upload-time = "2026-05-13T14:54:15.295Z" }, - { url = "https://files.pythonhosted.org/packages/e2/d2/a7dd5a3f9bdaa7842124e8e2359202b317c48d47d2fc5816fafdf2049adb/torch-2.12.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:415c1b8d0412f67551c8e89a2daca0fb3e56694af0281ba155eaa9da481f58b4", size = 532170085, upload-time = "2026-05-13T14:55:20.788Z" }, - { url = "https://files.pythonhosted.org/packages/12/1b/a61ce2004f9ab0ea8964a6e6168133a127795667639e2ff4f8f2bdb16a65/torch-2.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:dd37188ea325042cb1f6cafa56822b11ada2520c04791a52629b0af25bdfbfd9", size = 122953128, upload-time = "2026-05-13T14:54:52.744Z" }, - { url = "https://files.pythonhosted.org/packages/ef/bb/285d643f254731294c9b595a007eac39db4600a98682d7bca688f42ca164/torch-2.12.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:b41339df93d491435e790ff8bcbae1c0ce777175889bfd1281d119862793e6a2", size = 88010197, upload-time = "2026-05-13T14:55:35.414Z" }, - { url = "https://files.pythonhosted.org/packages/79/81/76debf1db1343bd929bbb5d74c89fb437c2ed88eb144712557e7bd3eea45/torch-2.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8fbef9f108a863e7722a73740998967e3b074742a834fc5be3a535a2befa7057", size = 426376751, upload-time = "2026-05-13T14:55:03.353Z" }, - { url = "https://files.pythonhosted.org/packages/de/f0/80026028b603c4650ff270fc3785bdef4bd6738765a9cc5a0f5a637d65a2/torch-2.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4b4f64c2c2b11f7510d93dd6412b87025ff6eddd6bb61c3b5a3d892ea20c4756", size = 532261691, upload-time = "2026-05-13T14:52:54.453Z" }, - { url = "https://files.pythonhosted.org/packages/b9/c2/64b06cbb7830fb3cd9be13e1158b31a3f36b68e6a209105ee3c9d9480be0/torch-2.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:8b958caff4a14d3a3b0b2dfc6a378f64dda9728a9dad28c08a0db9ce4dafb549", size = 122988114, upload-time = "2026-05-13T14:54:42.153Z" }, - { url = "https://files.pythonhosted.org/packages/86/ca/01896c80ba921676aa45886b2c5b8d774912de2a1f719de48169c6f755cd/torch-2.12.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:90dd587a5f61bfe1307148b581e2084fc5bc4a06e2b90a20e9a36b81087ff16b", size = 88009511, upload-time = "2026-05-13T14:54:47.411Z" }, - { url = "https://files.pythonhosted.org/packages/a5/04/52bdaf4787eab6ac7d7f5851dff934e4def0bc8ead9c8fd2b69b3e529699/torch-2.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:864392c73b7654f4d2b3ae712f607937d0dbb1101c4555fbb41848106b297f39", size = 426383231, upload-time = "2026-05-13T14:53:32.129Z" }, - { url = "https://files.pythonhosted.org/packages/49/8a/94bdecd13f5aaa90d45920b89789d9fe7c6f4af8c3cdd7ce01fcb59908fc/torch-2.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:5d6b560dfa7d56291c07d615c3bb73e8d9943d9b6d87f76cd0d9d570c4797fa6", size = 532269288, upload-time = "2026-05-13T14:53:49.423Z" }, - { url = "https://files.pythonhosted.org/packages/3e/2f/bdbaaa267de519ef1b73054bf590d8c93c37a266c9a4e24a01bd38b6918f/torch-2.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:3fee918902090ade827643e758e98363278815de583c75d111fdd665ebffde9f", size = 122987706, upload-time = "2026-05-13T14:54:00.335Z" }, - { url = "https://files.pythonhosted.org/packages/9b/ad/e95e822f3538171e22640a7fbe839a1fdb666600bf6487025de2ff03b11a/torch-2.12.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:10ee1448a9f304d3b987eb4656f664ba6e4d7b410ca7a5a7c642199777a2cf88", size = 88319556, upload-time = "2026-05-13T14:54:05.574Z" }, - { url = "https://files.pythonhosted.org/packages/b7/07/055d06d985b445d67422d25b033c11cf55bbb81785d4c4e68e28bca5820e/torch-2.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:af68dbf403439cae9ceaeaaf92f8352b460787dcd27b92aa05c40dd4a19c0f1e", size = 426397656, upload-time = "2026-05-13T14:52:38.84Z" }, - { url = "https://files.pythonhosted.org/packages/43/94/b0b4fdc3014122e0a7302fb90086d352aa48f2576f0b252561ebb38c01a8/torch-2.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:a6a2eebb237d3b1d9ad3b378e86d9b9e0782afdea8b1e0eba6a13646b9b49c07", size = 532183124, upload-time = "2026-05-13T14:53:16.178Z" }, - { url = "https://files.pythonhosted.org/packages/d8/c8/052405e6ad05d3237bfe5a4df78f917773956f8e17813a2d44c059068b74/torch-2.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:2140e373e9a51a3e22ef62e8d14366d0b470d18f0adf19fdc757368077133a34", size = 123232462, upload-time = "2026-05-13T14:52:27.26Z" }, - { url = "https://files.pythonhosted.org/packages/67/dc/ac069f8d6e8be701535921141055293b0d4819d3d7f224a4612cf157c7f9/torch-2.12.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:f7dfae4a519197dfa050e98d8e36378a0fb5899625a875c2b54445005a2e404e", size = 88027282, upload-time = "2026-05-13T14:53:05.258Z" }, - { url = "https://files.pythonhosted.org/packages/33/c3/1c1eb00e34555b536dddf792676026a988d710ed36981aa00499b36b0620/torch-2.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:891c769072637c74e9a5a77a3bc782894696d8ffec83b938df8536dee7f0ba78", size = 426386961, upload-time = "2026-05-13T14:51:28.406Z" }, - { url = "https://files.pythonhosted.org/packages/cd/d4/7e730dba0c7032a4154dc9056b76cf9625515e030e269cfbf8098fcfee7d/torch-2.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:e2ad3eb85d39c3cab62dfa93ed5a73516e6a53c6713cb97d004004fe089f0f1f", size = 532272265, upload-time = "2026-05-13T14:51:59.308Z" }, - { url = "https://files.pythonhosted.org/packages/f1/b4/92c80d1bbfee1c0036c06d1d2155a3065bd2423134c83bf8a47e65cd6b9b/torch-2.12.0-cp314-cp314-win_amd64.whl", hash = "sha256:c66696857e987efb8bc1777a37357ec4f60ab5e8af6250b83d6034437fa2d8f3", size = 122987138, upload-time = "2026-05-13T14:51:45.942Z" }, - { url = "https://files.pythonhosted.org/packages/7b/78/2e12b37ce50a19a037d7bc62d652a5a8f27385a7b05859d6bc9204f20cfe/torch-2.12.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:b4556715c8572758625d62b6e0ae3b1f76c440221913a6fb5e100f321fb4fb02", size = 88320100, upload-time = "2026-05-13T14:51:39.955Z" }, - { url = "https://files.pythonhosted.org/packages/56/5e/83c450ec7b0bb40a7b74611c1b5440f9260e33c54c90d556fd4a1f0fd955/torch-2.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:a43ac605a5e13116c72b64c359644cce0229f213dde48d2ae0ae5eb5becf7feb", size = 426391871, upload-time = "2026-05-13T14:52:14.989Z" }, - { url = "https://files.pythonhosted.org/packages/c9/e9/1a0b575d98d0afedd8f157d23fa3d2759421483660448e60d0a4b10b6daa/torch-2.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:6a7512adfdd7f6732e40de1c620831e3c75b39b98cef60b11d0c5f0a76473ec5", size = 532192241, upload-time = "2026-05-13T14:51:07.795Z" }, - { url = "https://files.pythonhosted.org/packages/88/21/afadd25ecd81b3cea1e11c73cf1ab41a983a50271548c3ec7ec3b9efc3e9/torch-2.12.0-cp314-cp314t-win_amd64.whl", hash = "sha256:5f96b63f8287f66a005dd1b5a6abba2920f11156c5e5c4d815f3e2050fd1aa16", size = 123231092, upload-time = "2026-05-13T14:51:18.854Z" }, -] - [[package]] name = "tornado" version = "6.5.5" @@ -3651,27 +3318,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, ] -[[package]] -name = "triton" -version = "3.7.0" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3e/97/dcd1f2a0f8336691bff74abc59b2ed9c69a0c0f8f65cd77109c49e05f068/triton-3.7.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223ac302091491436c248a34ee1e6c47a1026486579103c906ffd805be50cb89", size = 188367104, upload-time = "2026-05-07T19:04:56.68Z" }, - { url = "https://files.pythonhosted.org/packages/b2/c0/c2ac4fd2d8809b7579d4a820a0f9e5de62a9bc8a757ed4b3abf4f7ee964a/triton-3.7.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c631b65668d4951213b948a413c0564184305b77bb45cc9d686d3e1ecc4701a3", size = 201313191, upload-time = "2026-05-07T18:45:58.444Z" }, - { url = "https://files.pythonhosted.org/packages/b8/c1/5d842314bb6c78442cc60437928781701c6050b8d479bc2a1aed691d37ca/triton-3.7.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a9e71fc392675fac364e0ecf4ef3f76f85b7f5433a16f4c3c5fe5f05a52c85fe", size = 188480277, upload-time = "2026-05-07T19:05:03.231Z" }, - { url = "https://files.pythonhosted.org/packages/13/31/8315ea5f8dd18e60970b3022e3a8b93fd37e0b784fbbef86e10c8e6e5ca1/triton-3.7.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22bacffce443f54593dd20f05294d5a40622e0ea9ab632816f87154504356221", size = 201415942, upload-time = "2026-05-07T18:46:06.479Z" }, - { url = "https://files.pythonhosted.org/packages/f7/13/ec05adfcd87311d532ba61e3af143e8be59fcd26675884c4682841406a20/triton-3.7.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a4bf49b00a7a377a68a6da603a876e797614e6455a80e9021669c476a953ad9a", size = 188505104, upload-time = "2026-05-07T19:05:09.843Z" }, - { url = "https://files.pythonhosted.org/packages/62/7b/468a576e35beef1426e0828e28e9ba9e65f5474d496f16ee126c15646324/triton-3.7.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8f111161d49bf903c0eaedde3962353a3d841c08a836839b7cc1025b8426efcf", size = 201457567, upload-time = "2026-05-07T18:46:13.505Z" }, - { url = "https://files.pythonhosted.org/packages/01/e1/a59a583de59b8f62c495d67c80ee3ea97d09e91ac80c4c6e76456ed8d8ac/triton-3.7.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:abdf6beaa89b1bcfb9a43cd990536ce66091a997841a4814b260b7bee4c88c3c", size = 188503209, upload-time = "2026-05-07T19:05:17.935Z" }, - { url = "https://files.pythonhosted.org/packages/30/b1/b7507bb9815d403927c8dd51d4158ed2e11751a92dbc118a044f247b6848/triton-3.7.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a35d7afe3f3f058e7ec49fcce09794049e0ffc5c59019ac25ec3413741b8c4e7", size = 201453566, upload-time = "2026-05-07T18:46:20.427Z" }, - { url = "https://files.pythonhosted.org/packages/a6/8f/0bea7a6a0c989315c9135a1d7fb37e41905cfb3a17cbc1f10044ebd4cc3a/triton-3.7.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cc1d61c172d257db80ddf42595131fb196ad2e9bdd751e90fe2ef13531734e8b", size = 188612899, upload-time = "2026-05-07T19:05:24.955Z" }, - { url = "https://files.pythonhosted.org/packages/e1/02/d96f57828d0912aec733b9bc7e0e7dbfd2c6f079a8fa433ac25cb93d1a30/triton-3.7.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:70fb9bbdc9f400afc54bbf6eb2670af28829a6ae3996863317964783141daf56", size = 201553816, upload-time = "2026-05-07T18:46:27.49Z" }, - { url = "https://files.pythonhosted.org/packages/40/fb/82a802dac4689f2a2fb2e69302e6a138eecc3e175bbe976ba3cfc717683a/triton-3.7.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4a44a8476d0d3571eac4e4d1048e1ff75aad81a09ff4602ccfc56c6dea1672e", size = 188507879, upload-time = "2026-05-07T19:05:32.209Z" }, - { url = "https://files.pythonhosted.org/packages/8f/af/9904ec6d3c93d9b24e5ec360445bbdf758b7f00bfbeedb89cb0eb64eb8bb/triton-3.7.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b9b85e72968a9d8bba5ddb24e9b64aaabaf48affb042f2755cb7cfa92b7531ce", size = 201460637, upload-time = "2026-05-07T18:46:34.749Z" }, - { url = "https://files.pythonhosted.org/packages/a1/f9/4835a8ea746b88727d8899f4e3ccce4f9cacb38abfc3bb0a638266c53111/triton-3.7.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18a160de426fd99f92b0baf509045360afbd3bfaa0b4a5171dde800ec9f09684", size = 188608706, upload-time = "2026-05-07T19:05:39.218Z" }, - { url = "https://files.pythonhosted.org/packages/c1/68/fa86e5a39608000f645535b2c124920126327ab731f8c4fafd5b07ff8d4b/triton-3.7.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce061073102714b725f3660ec6939d94a1da7984b3aa99c921417cae273672f5", size = 201546766, upload-time = "2026-05-07T18:46:42.088Z" }, -] - [[package]] name = "twine" version = "6.2.0"