From 791be81914bfde7ced846c5942f0f1d0e45c71b8 Mon Sep 17 00:00:00 2001 From: Jules Date: Fri, 26 Jun 2026 00:37:01 +0000 Subject: [PATCH] feat: expand AST detection for in-place EEG edits and inject plotting syncs --- src/eegprep/functions/adminfunc/console.py | 55 ++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/src/eegprep/functions/adminfunc/console.py b/src/eegprep/functions/adminfunc/console.py index 032d3b04..9eac85f6 100644 --- a/src/eegprep/functions/adminfunc/console.py +++ b/src/eegprep/functions/adminfunc/console.py @@ -30,6 +30,7 @@ _EEG_CORE_FIELDS = ("nbchan", "srate", "pnts", "trials") _ACTIVE_TERMINAL_BUFFER = contextvars.ContextVar("_ACTIVE_TERMINAL_BUFFER", default=None) _MATLAB_MULTI_ASSIGN_PATTERN = re.compile(r"^\s*\[([A-Za-z_][A-Za-z0-9_]*(?:\s+[A-Za-z_][A-Za-z0-9_]*)+)\]\s*=") +_MUTATING_METHODS = {"pop", "update", "clear", "setdefault", "insert", "remove", "append", "extend", "fill"} _TUPLE_ASSIGNMENT_TARGET_PATTERN = re.compile( r"(^|;\s*)\(([A-Za-z_][A-Za-z0-9_]*(?:,\s*[A-Za-z_][A-Za-z0-9_]*)+)\)\s*=" ) @@ -339,6 +340,14 @@ def close(self) -> None: self.session.remove_gui_action_listener(self._gui_action_event) self._finish_gui_action_output() + def sync_state(self) -> None: + """Force the workspace to push its state to the session and refresh the GUI.""" + eeg = self.namespace.get("EEG") + if _is_eeg_selection(eeg): + self._store_eeg(eeg, "") + self.pull_from_session() + self._refresh() + def pull_from_session(self) -> None: """Mirror session state into the console namespace.""" self.namespace["eegprep"] = self._eegprep_proxy @@ -350,6 +359,7 @@ def pull_from_session(self) -> None: self.namespace["LASTCOM"] = self.session.LASTCOM self.namespace["STUDY"] = self.session.STUDY self.namespace["CURRENTSTUDY"] = self.session.CURRENTSTUDY + self.namespace["refresh"] = self.sync_state def after_execute(self, source: str, *, success: bool = True) -> None: """Push console-side workspace edits back into the session.""" @@ -661,6 +671,8 @@ def post_run_cell(result: Any) -> None: _safe_after_execute(self.workspace, raw_cell, success=success, write=sys.stderr.write) self.shell.events.register("post_run_cell", post_run_cell) + if hasattr(self.shell, "ast_transformers"): + self.shell.ast_transformers.append(_PlotSyncInjector()) try: self.shell() finally: @@ -937,6 +949,43 @@ def _normalise_tuple_assignment_targets(text: str) -> str: return _TUPLE_ASSIGNMENT_TARGET_PATTERN.sub(lambda match: f"{match.group(1)}{match.group(2)} =", text) +class _PlotSyncInjector(ast.NodeTransformer): + def visit_Expr(self, node: ast.Expr) -> Any: + self.generic_visit(node) + return self._inject_if_needed(node) + + def visit_Assign(self, node: ast.Assign) -> Any: + self.generic_visit(node) + return self._inject_if_needed(node) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> Any: + self.generic_visit(node) + return self._inject_if_needed(node) + + def visit_AugAssign(self, node: ast.AugAssign) -> Any: + self.generic_visit(node) + return self._inject_if_needed(node) + + def _inject_if_needed(self, node: ast.stmt) -> list[ast.stmt] | ast.stmt: + has_ui_call = False + for child in ast.walk(node): + if isinstance(child, ast.Call) and isinstance(child.func, ast.Name): + if child.func.id.startswith("pop_") or child.func.id in {"eegplot", "eeg_browser", "eegbrowser"}: + has_ui_call = True + break + if has_ui_call: + refresh_call = ast.Expr( + value=ast.Call( + func=ast.Name(id="refresh", ctx=ast.Load()), + args=[], + keywords=[], + ) + ) + ast.copy_location(refresh_call, node) + return [refresh_call, node] + return node + + class _ConsoleCommandArgumentConverter(ast.NodeTransformer): def __init__(self) -> None: self.changed = False @@ -1415,6 +1464,12 @@ def _workspace_assignment_targets(source: str) -> set[str]: raw_targets = node.targets if isinstance(node, ast.Assign) else [node.target] for target in raw_targets: targets.update(root for root in _target_root_names(target) if root in WORKSPACE_NAMES) + elif isinstance(node, ast.Delete): + for target in node.targets: + targets.update(root for root in _target_root_names(target) if root in WORKSPACE_NAMES) + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: + targets.update(root for root in _target_root_names(node.func.value) if root in WORKSPACE_NAMES) return targets