本项目用于构建并训练一个从人脸 blendshape(52维) 到 18维舵机指令 的映射模型,主要流程为:
- 控制机器人随机动作并采集图像/指令对
- 从图像提取 MediaPipe blendshape
- 训练 Attention + Vectorized KAN 回归模型
- 在测试集评估整体误差和各舵机误差
generate_dataset.py
采集数据集:控制舵机 + OpenCV 采图,输出同名配对文件dataset/images/sample_XXXXXX.jpgdataset/commands/sample_XXXXXX.txt
servo_constraints_fixed.py
18个舵机配置、范围裁剪和动作约束规则preprocess_blendshape_to_pt.py
用 MediaPipe Tasks 提取 52维 blendshape,并生成train.pt/test.ptmodel_attention_kan_jhr_vectorized.py
模型定义(Attention 编码器 + 向量化 RBF-KAN 回归头)与约束损失train_jhr.py
训练脚本,输出outputs_jhr_attention_kan/best.pt和last.pttest_jhr.py
测试脚本,输出测试集指标与各舵机 MAEmodels/face_landmarker.task
MediaPipe Face Landmarker 模型文件
建议 Python 3.9+。
安装依赖:
pip install numpy torch opencv-python mediapipe pyserial tqdm pillow采集前请根据你的硬件修改:
generate_dataset.py中串口:port='COM3'generate_dataset.py中相机索引:cv2.VideoCapture(0)
python generate_dataset.py输出目录:
dataset/images/*.jpgdataset/commands/*.txt
图像与指令一一对应,文件名一致(仅扩展名不同)。
python preprocess_blendshape_to_pt.py ^
--dataset_dir dataset ^
--images_subdir images ^
--commands_subdir commands ^
--model_path models/face_landmarker.task ^
--out_train train.pt ^
--out_test test.pt ^
--y_format raw说明:
--y_format raw:保存原始舵机 ticks--y_format norm:保存到[-1,1]归一化空间
python train_jhr.py默认输出:
outputs_jhr_attention_kan/best.ptoutputs_jhr_attention_kan/last.pt
python test_jhr.py默认读取:
- checkpoint:
outputs_jhr_attention_kan/best.pt - 测试集:
test.pt
输出:
test_loss / test_mse / test_mae_ticks- 各舵机 MAE(按误差从大到小排序)
- 单样本指令文件:
sample_XXXXXX.txt- 一行,
18个浮点数(舵机指令)
- 一行,
- 训练数据
train.pt/test.pt:X:[N, 52]float32(blendshape)Y:[N, 18]float32(raw 或 norm)
-
ModuleNotFoundError: No module named 'servo_constraints'
参考“运行前注意事项”,统一约束文件导入名。 -
FaceLandmarker model not found
检查--model_path,确认models/face_landmarker.task存在。 -
相机打不开或串口打不开
检查设备连接、端口号、相机索引、权限占用情况。