Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/eegprep/functions/adminfunc/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
_EEG_CORE_FIELDS = ("nbchan", "srate", "pnts", "trials")
_ACTIVE_TERMINAL_BUFFER = contextvars.ContextVar("_ACTIVE_TERMINAL_BUFFER", default=None)
_MATLAB_MULTI_ASSIGN_PATTERN = re.compile(r"^\s*\[([A-Za-z_][A-Za-z0-9_]*(?:\s+[A-Za-z_][A-Za-z0-9_]*)+)\]\s*=")
_MUTATING_METHODS = {"pop", "update", "clear", "setdefault", "insert", "remove", "append", "extend", "fill"}
_TUPLE_ASSIGNMENT_TARGET_PATTERN = re.compile(
r"(^|;\s*)\(([A-Za-z_][A-Za-z0-9_]*(?:,\s*[A-Za-z_][A-Za-z0-9_]*)+)\)\s*="
)
Expand Down Expand Up @@ -339,6 +340,14 @@ def close(self) -> None:
self.session.remove_gui_action_listener(self._gui_action_event)
self._finish_gui_action_output()

def sync_state(self) -> None:
"""Force the workspace to push its state to the session and refresh the GUI."""
eeg = self.namespace.get("EEG")
if _is_eeg_selection(eeg):
self._store_eeg(eeg, "")
self.pull_from_session()
self._refresh()

def pull_from_session(self) -> None:
"""Mirror session state into the console namespace."""
self.namespace["eegprep"] = self._eegprep_proxy
Expand All @@ -350,6 +359,7 @@ def pull_from_session(self) -> None:
self.namespace["LASTCOM"] = self.session.LASTCOM
self.namespace["STUDY"] = self.session.STUDY
self.namespace["CURRENTSTUDY"] = self.session.CURRENTSTUDY
self.namespace["refresh"] = self.sync_state

def after_execute(self, source: str, *, success: bool = True) -> None:
"""Push console-side workspace edits back into the session."""
Expand Down Expand Up @@ -661,6 +671,8 @@ def post_run_cell(result: Any) -> None:
_safe_after_execute(self.workspace, raw_cell, success=success, write=sys.stderr.write)

self.shell.events.register("post_run_cell", post_run_cell)
if hasattr(self.shell, "ast_transformers"):
self.shell.ast_transformers.append(_PlotSyncInjector())
try:
self.shell()
finally:
Expand Down Expand Up @@ -937,6 +949,43 @@ def _normalise_tuple_assignment_targets(text: str) -> str:
return _TUPLE_ASSIGNMENT_TARGET_PATTERN.sub(lambda match: f"{match.group(1)}{match.group(2)} =", text)


class _PlotSyncInjector(ast.NodeTransformer):
def visit_Expr(self, node: ast.Expr) -> Any:
self.generic_visit(node)
return self._inject_if_needed(node)

def visit_Assign(self, node: ast.Assign) -> Any:
self.generic_visit(node)
return self._inject_if_needed(node)

def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
self.generic_visit(node)
return self._inject_if_needed(node)

def visit_AugAssign(self, node: ast.AugAssign) -> Any:
self.generic_visit(node)
return self._inject_if_needed(node)

def _inject_if_needed(self, node: ast.stmt) -> list[ast.stmt] | ast.stmt:
has_ui_call = False
for child in ast.walk(node):
if isinstance(child, ast.Call) and isinstance(child.func, ast.Name):
if child.func.id.startswith("pop_") or child.func.id in {"eegplot", "eeg_browser", "eegbrowser"}:
has_ui_call = True
break
if has_ui_call:
refresh_call = ast.Expr(
value=ast.Call(
func=ast.Name(id="refresh", ctx=ast.Load()),
args=[],
keywords=[],
)
)
ast.copy_location(refresh_call, node)
return [refresh_call, node]
return node


class _ConsoleCommandArgumentConverter(ast.NodeTransformer):
def __init__(self) -> None:
self.changed = False
Expand Down Expand Up @@ -1415,6 +1464,12 @@ def _workspace_assignment_targets(source: str) -> set[str]:
raw_targets = node.targets if isinstance(node, ast.Assign) else [node.target]
for target in raw_targets:
targets.update(root for root in _target_root_names(target) if root in WORKSPACE_NAMES)
elif isinstance(node, ast.Delete):
for target in node.targets:
targets.update(root for root in _target_root_names(target) if root in WORKSPACE_NAMES)
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
targets.update(root for root in _target_root_names(node.func.value) if root in WORKSPACE_NAMES)
return targets


Expand Down
Loading