Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,19 @@ jobs:
run: |
uv run --no-sync pytest tests

- name: Sync Parity Status
if: steps.matlab-engine-availability.outputs.available == 'true'
run: |
uv run --no-sync python tools/sync_parity_status.py

- name: Commit and Push Parity Matrix
if: steps.matlab-engine-availability.outputs.available == 'true' && github.ref == 'refs/heads/master' && success()
run: |
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
git add docs/parity/eeglab_core_parity_matrix.json
git diff --quiet && git diff --staged --quiet || (git commit -m "chore: sync parity matrix based on live test outcomes" && git push)

- name: Display installed packages
if: always()
run: |
Expand Down
26 changes: 26 additions & 0 deletions src/eegprep/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,32 @@ def compare_eeg(a, b, rtol=0, atol=1e-7, use_32_bit=default_32_bit, err_msg=''):
summary = "\n".join(summary_lines)
print(f"Actual differences: rtol: {max_rel_diff}, atol: {max_abs_diff}")

import re
import json
from pathlib import Path

# Try to extract function name from err_msg for parity tracking
match = re.search(r'([a-zA-Z0-9_]+)\(\)', err_msg)
if match:
func_name = match.group(1)
metrics_file = Path('.parity_metrics.json')
metrics = {}
if metrics_file.exists():
try:
metrics = json.loads(metrics_file.read_text())
except Exception:
pass

current = metrics.get(func_name, {})
# Keep worst-case RMS difference
if float(rms_diff) > float(current.get('rms_diff', -1)):
metrics[func_name] = {
'rms_diff': float(rms_diff),
'mean_diff': float(mean_abs_diff),
'max_diff': float(max_abs_diff)
}
metrics_file.write_text(json.dumps(metrics, indent=2))

# Perform the assertion
try:
np.testing.assert_allclose(a_flat, b_flat, rtol=rtol, atol=atol, err_msg=err_msg)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_eeg_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,33 @@ def _compare_eeg_results(self, py_result, ml_result):
self.assertLess(max_abs_diff, 5e-2, f"Max absolute difference: {max_abs_diff}")
self.assertLess(max_rel_diff, 5e-2, f"Max relative difference: {max_rel_diff}")

# Record metrics for parity
diff = py_result['data'].flatten() - ml_result['data'].flatten()
abs_diff = np.abs(diff)
rms_diff = np.sqrt(np.mean(diff**2))
mean_abs_diff = np.mean(abs_diff)
max_diff = np.max(abs_diff)

import json
import pathlib
metrics_file = pathlib.Path('.parity_metrics.json')
metrics = {}
if metrics_file.exists():
try:
metrics = json.loads(metrics_file.read_text())
except Exception:
pass

func_name = 'eeg_interp'
current = metrics.get(func_name, {})
if float(rms_diff) > float(current.get('rms_diff', -1)):
metrics[func_name] = {
'rms_diff': float(rms_diff),
'mean_diff': float(mean_abs_diff),
'max_diff': float(max_diff)
}
metrics_file.write_text(json.dumps(metrics, indent=2))

# Compare structure fields
self.assertEqual(py_result['nbchan'], ml_result['nbchan'])
self.assertEqual(py_result['pnts'], ml_result['pnts'])
Expand Down
87 changes: 87 additions & 0 deletions tools/sync_parity_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import json
import sys
import os
from pathlib import Path

def main():
matrix_path = Path("docs/parity/eeglab_core_parity_matrix.json")
metrics_path = Path(".parity_metrics.json")

if not matrix_path.exists():
print(f"Matrix not found at {matrix_path}")
return 1

if not metrics_path.exists():
print("No parity metrics found.")
return 0

with open(matrix_path, "r", encoding="utf-8") as f:
matrix = json.load(f)

with open(metrics_path, "r", encoding="utf-8") as f:
metrics = json.load(f)

has_regression = False
updated = False

for row in matrix.get("rows", []):
func_name = row.get("eeglab_name")
if func_name in metrics:
new_m = metrics[func_name]

# Get configurable thresholds
thresholds = row.get("thresholds", {})
max_allowed_rms = thresholds.get("rms_diff", 1e-5)
max_allowed_max = thresholds.get("max_diff", 1e-5)

old_m = row.get("metrics", {})

rms_exceeded = new_m["rms_diff"] > max_allowed_rms
max_exceeded = new_m["max_diff"] > max_allowed_max

# Did it regress compared to what was previously recorded?
if old_m:
if new_m["rms_diff"] > old_m.get("rms_diff", max_allowed_rms) * 1.5:
rms_exceeded = True
if new_m["max_diff"] > old_m.get("max_diff", max_allowed_max) * 1.5:
max_exceeded = True

if rms_exceeded or max_exceeded:
if row.get("status") == "implemented":
state = "Regressed"
has_regression = True
else:
state = "In Progress"

print(f"[{state}] {func_name}: RMS={new_m['rms_diff']} (limit={max_allowed_rms}), Max={new_m['max_diff']} (limit={max_allowed_max})")

# If it's in progress, we can still record its metrics if it improved
if state == "In Progress":
if not old_m or new_m["rms_diff"] < old_m.get("rms_diff", float('inf')) or new_m["max_diff"] < old_m.get("max_diff", float('inf')):
row["metrics"] = new_m
updated = True
else:
state = "Verified Parity"
print(f"[{state}] {func_name}: RMS={new_m['rms_diff']}, Max={new_m['max_diff']}")
# Only update if improved or not recorded
if new_m["rms_diff"] < old_m.get("rms_diff", float('inf')) or new_m["max_diff"] < old_m.get("max_diff", float('inf')) or not old_m:
row["metrics"] = new_m
row["status"] = "implemented"
updated = True
else:
row["metrics"] = new_m
updated = True

if updated:
with open(matrix_path, "w", encoding="utf-8") as f:
json.dump(matrix, f, indent=2)
f.write("\n")

if has_regression:
print("Build failed due to parity regressions.")
return 1

return 0

if __name__ == "__main__":
sys.exit(main())
Loading