Files
JE-Skin/devkit/sensor_server.py

602 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
JE-Skin DevKit — Python gRPC Sensor Server
提供两个服务:
1. SensorPush (streaming) — 接收实时传感器帧
2. ExportProcessor (unary) — 处理导出的 CSV 文件梯度过滤、xlsx 转换
安装依赖:
pip install grpcio grpcio-tools openpyxl
生成 gRPC 代码:
python -m grpc_tools.protoc -I../src-tauri/proto --python_out=. --grpc_python_out=. ../src-tauri/proto/sensor_stream.proto
启动:
python sensor_server.py [--port 50051]
"""
from __future__ import annotations
import argparse
import csv
import os
import signal
import statistics
import sys
import time
from concurrent import futures
from pathlib import Path
import grpc
import sensor_stream_pb2
import sensor_stream_pb2_grpc
# ── Gradient filtering logic (from the user's main.py) ─────────────────────────
def load_rows(path: Path) -> list[list[str]]:
    """Read a CSV file and return every non-empty row as a list of strings.

    The file is decoded as ``utf-8-sig`` so a leading byte-order mark is
    dropped. Records the CSV reader yields as empty lists are skipped.
    """
    kept: list[list[str]] = []
    with path.open("r", encoding="utf-8-sig", newline="") as handle:
        for record in csv.reader(handle):
            if record:
                kept.append(record)
    return kept
def row_sum(row: list[str]) -> float:
    """Return the sum of a row's numeric cells, ignoring the first column.

    Cells that are empty or whitespace-only are skipped; every other cell
    after the first is parsed with ``float``.
    """
    numbers = [float(cell) for cell in row[1:] if cell.strip()]
    return sum(numbers)
def find_threshold(sum_values: list[float]) -> float:
    """Return the midpoint of the widest gap between consecutive sorted values.

    Raises:
        ValueError: if fewer than two values are supplied.
    """
    if len(sum_values) < 2:
        raise ValueError("At least two rows are required.")
    ordered = sorted(sum_values)
    # gaps[i] is the distance between the i-th and (i+1)-th smallest values.
    gaps = [upper - lower for lower, upper in zip(ordered, ordered[1:])]
    # max() returns the first index on ties, i.e. the lowest-valued widest gap.
    split = max(range(len(gaps)), key=gaps.__getitem__)
    return (ordered[split] + ordered[split + 1]) / 2.0
def extract_press_groups(
    rows: list[list[str]], sum_values: list[float], threshold: float
) -> tuple[list[list[str]], list[float]]:
    """Keep rows whose sum reaches the threshold and average each contiguous run.

    Returns:
        A pair ``(kept_rows, run_means)``: the rows whose sum is
        ``>= threshold`` in their original order, and the mean row-sum of
        each uninterrupted run of such rows.
    """
    kept_rows: list[list[str]] = []
    run_means: list[float] = []
    run: list[float] = []

    for record, total in zip(rows, sum_values):
        if total >= threshold:
            kept_rows.append(record)
            run.append(total)
        elif run:
            # A below-threshold row closes the run currently being collected.
            run_means.append(statistics.fmean(run))
            run = []

    # Close a run that extends to the last row.
    if run:
        run_means.append(statistics.fmean(run))
    return kept_rows, run_means
def write_csv(path: Path, rows: list[list[str]]) -> Path:
    """Write ``rows`` to ``<stem>_filtered.csv`` beside ``path`` and return that path.

    The output is encoded as ``utf-8-sig``.
    """
    target = path.with_name(path.stem + "_filtered.csv")
    with target.open("w", encoding="utf-8-sig", newline="") as handle:
        writer = csv.writer(handle)
        for record in rows:
            writer.writerow(record)
    return target
def write_xlsx(path: Path, rows: list[list[str]], stats: dict) -> Path:
    """Write the filtered rows and summary statistics to an xlsx workbook.

    The workbook is saved next to ``path`` as ``<stem>_filtered.xlsx`` and
    has two sheets: "Filtered Data" (one worksheet row per input row) and
    "Statistics" (label/value pairs taken from ``stats``).

    Args:
        path: Path of the source CSV; only its directory and stem are used.
        rows: Filtered rows as lists of strings.
        stats: Mapping that may contain ``source_file``, ``rows_total``,
            ``rows_kept``, ``groups_used``, ``mean_value``, ``threshold``
            and ``process_time``; missing keys fall back to defaults.

    Returns:
        Path of the written workbook.

    Raises:
        RuntimeError: if openpyxl is not installed.
    """
    try:
        import openpyxl
        from openpyxl.styles import Font, PatternFill
    except ImportError as err:
        raise RuntimeError("openpyxl is required for xlsx output. Install it with: pip install openpyxl") from err

    def to_cell(text: str):
        """Return ``text`` as a float when it is a plain decimal number, else unchanged."""
        # The digit check is only a cheap pre-filter: strings such as
        # "2024-01-01" or "1.2.3" pass it but are not numbers, so float()
        # must be allowed to fail and the cell kept as text.
        if text.strip().replace(".", "").replace("-", "").isdigit():
            try:
                return float(text)
            except ValueError:
                return text
        return text

    wb = openpyxl.Workbook()

    # Sheet 1: the filtered data.
    ws_data = wb.active
    ws_data.title = "Filtered Data"
    for row in rows:
        ws_data.append([to_cell(c) for c in row])

    # Sheet 2: summary statistics with a highlighted header row.
    ws_stats = wb.create_sheet("Statistics")
    header_font = Font(bold=True, size=11)
    header_fill = PatternFill(start_color="E0E0E0", end_color="E0E0E0", fill_type="solid")
    ws_stats.append(["Parameter", "Value"])
    for cell_ref in ("A1", "B1"):
        ws_stats[cell_ref].font = header_font
        ws_stats[cell_ref].fill = header_fill

    stats_rows = [
        ("Source File", stats.get("source_file", "")),
        ("Total Rows", stats.get("rows_total", 0)),
        ("Filtered Rows", stats.get("rows_kept", 0)),
        ("Groups Used", stats.get("groups_used", 0)),
        ("Mean Value", f"{stats.get('mean_value', 0):.3f}"),
        ("Threshold", f"{stats.get('threshold', 0):.3f}"),
        ("Process Time", stats.get("process_time", "")),
    ]
    for label, value in stats_rows:
        ws_stats.append([label, value])
    ws_stats.column_dimensions["A"].width = 18
    ws_stats.column_dimensions["B"].width = 30

    out = path.with_name(f"{path.stem}_filtered.xlsx")
    wb.save(str(out))
    return out
def process_csv(csv_path: str, save_as_xlsx: bool) -> dict:
    """Run the gradient filter on a CSV file and return result statistics.

    The row sums are split at the widest gap (``find_threshold``), rows at or
    above that threshold are kept, and the result is written either as an
    xlsx workbook (the source CSV is then deleted) or as a CSV that replaces
    the source file. A summary row is appended to the analysis log.

    Args:
        csv_path: Path of the CSV file to process.
        save_as_xlsx: Write ``<stem>_filtered.xlsx`` instead of replacing the
            CSV in place.

    Returns:
        Dict with keys ``ok``, ``output_path``, ``groups_used``,
        ``mean_value``, ``threshold``, ``rows_total``, ``rows_kept`` and
        ``message``.

    Raises:
        FileNotFoundError: if ``csv_path`` is not an existing file.
        ValueError: if the file is empty, has fewer than two rows, or no row
            reaches the threshold.
        RuntimeError: if xlsx output is requested but openpyxl is missing.
    """
    path = Path(csv_path)
    if not path.is_file():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")
    rows = load_rows(path)
    if not rows:
        raise ValueError("CSV file is empty.")

    sum_values = [row_sum(r) for r in rows]
    threshold = find_threshold(sum_values)
    filtered_rows, group_means = extract_press_groups(rows, sum_values, threshold)
    if not filtered_rows:
        raise ValueError("No large press-down data was found.")
    overall_mean = statistics.fmean(group_means)

    stats = {
        "source_file": path.name,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "process_time": time.strftime("%Y-%m-%d %H:%M:%S"),
    }

    if save_as_xlsx:
        output_path = write_xlsx(path, filtered_rows, stats)
        # Removing the source CSV is best-effort; the workbook already exists.
        try:
            path.unlink()
        except OSError as err:
            print(f"[ExportProcessor] Could not remove source CSV {path}: {err}")
    else:
        filtered_path = write_csv(path, filtered_rows)
        # Path.replace overwrites the source in a single step, so there is no
        # window in which the source is gone but the filtered file has not
        # yet taken its place.
        try:
            filtered_path.replace(path)
        except OSError as err:
            print(f"[ExportProcessor] Could not replace {path}: {err}")
            output_path = filtered_path
        else:
            output_path = path

    # Append one row to the summary workbook. By this point the source file
    # has already been replaced or removed, so a logging failure must not
    # turn a completed export into a reported failure (a retry would filter
    # the already-filtered data again).
    try:
        _append_analysis_log(csv_path, stats)
    except Exception as err:
        print(f"[ExportProcessor] Could not update analysis log: {err}")

    return {
        "ok": True,
        "output_path": str(output_path),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "message": "OK",
    }
def _append_analysis_log(source_csv: str, stats: dict):
    """Append one result row to ``devkit_analysis_results.xlsx``.

    The workbook lives in the same directory as ``source_csv``; it is created
    with a header row if it does not exist yet. Does nothing when openpyxl
    cannot be imported.
    """
    try:
        import openpyxl
    except ImportError:
        return  # openpyxl unavailable: skip logging

    log_path = Path(source_csv).parent / "devkit_analysis_results.xlsx"
    is_new_log = not log_path.exists()

    if is_new_log:
        workbook = openpyxl.Workbook()
    else:
        workbook = openpyxl.load_workbook(str(log_path))
    sheet = workbook.active

    if is_new_log:
        sheet.title = "Analysis Log"
        header = ["Time", "Source File", "Total Rows", "Kept Rows",
                  "Groups", "Mean Value", "Threshold", "Output File"]
        sheet.append(header)

    source_stem = Path(stats.get("source_file", "")).stem
    record = [
        stats.get("process_time", ""),
        stats.get("source_file", ""),
        stats.get("rows_total", 0),
        stats.get("rows_kept", 0),
        stats.get("groups_used", 0),
        round(stats.get("mean_value", 0), 3),
        round(stats.get("threshold", 0), 3),
        f"{source_stem}_filtered",
    ]
    sheet.append(record)
    workbook.save(str(log_path))
# ── gRPC 服务实现 ────────────────────────────────────────────────
class SensorPushServicer(sensor_stream_pb2_grpc.SensorPushServicer):
    """Receive streamed sensor frames, compute the PZT angle and log to CSV.

    Per-stream state (frame counter, CSV handle) is stored on the instance,
    and the angle computation uses module-level state, so this servicer is
    written for one active ``Upload`` stream at a time.
    """

    _csv_path = None  # class variable: path of the most recently opened CSV log

    def __init__(self):
        self.frame_count = 0
        self.last_report_time = time.time()
        self.last_angle = None
        self._csv_file = None
        self._csv_writer = None

    def _open_csv(self):
        """Open a new timestamped CSV file in the working directory and write its header."""
        ts = time.strftime("%Y%m%d_%H%M%S")
        SensorPushServicer._csv_path = os.path.join(os.getcwd(), f"sensor_log_{ts}.csv")
        self._csv_file = open(SensorPushServicer._csv_path, "w", newline="", encoding="utf-8-sig")
        self._csv_writer = csv.writer(self._csv_file)
        header = ["seq", "timestamp_ms", "dts_ms", "angle", "magnitude", "state", "resultant_force"] + [f"ch{i}" for i in range(SENSOR_ROWS * SENSOR_COLS)]
        self._csv_writer.writerow(header)
        self._csv_file.flush()
        print(f"[SensorPush] CSV logging to: {SensorPushServicer._csv_path}")

    def _close_csv(self):
        """Close the CSV file if one is open and forget the writer."""
        if self._csv_file:
            self._csv_file.close()
            print(f"[SensorPush] CSV saved: {SensorPushServicer._csv_path}")
        self._csv_file = None
        self._csv_writer = None

    def Upload(self, request_iterator, context):
        """Consume a stream of frames and yield one ``PztAngleResponse`` per frame.

        Each frame is validated, passed to ``get_pzt_angle`` and appended to
        the CSV log. A frame with the wrong matrix length, or one whose
        computation raises, is answered with ``ok=False`` and the reason in
        ``message``; the stream itself keeps going.
        """
        print("[SensorPush] Client connected, waiting for frames...")
        reset_baseline()
        self.last_angle = None
        self.frame_count = 0
        # Restart the fps clock so the first report of this stream does not
        # include the time the server spent waiting for a connection.
        self.last_report_time = time.time()
        try:
            self._open_csv()
            for frame in request_iterator:
                self.frame_count += 1
                angle = 0.0
                magnitude = 0.0
                state = 0
                ok = True
                message = "OK"
                if len(frame.matrix) == SENSOR_ROWS * SENSOR_COLS:
                    try:
                        angle, magnitude, state = get_pzt_angle(frame.matrix)
                        self.last_angle = angle
                        # Print the received data.
                        print(f"[Recv #{frame.seq}] seq={frame.seq} ts={frame.timestamp_ms} dts={frame.dts_ms} "
                              f"data={list(frame.matrix)}")
                    except Exception as e:
                        ok = False
                        message = str(e)
                        print(f"[SensorPush] PZT compute error on frame #{frame.seq}: {e}")
                else:
                    ok = False
                    message = f"Invalid matrix length: {len(frame.matrix)}"
                    print(f"[Recv #{frame.seq}] INVALID len={len(frame.matrix)}")

                # Append every frame (valid or not) to the CSV log.
                if self._csv_writer:
                    row = [frame.seq, frame.timestamp_ms, frame.dts_ms,
                           f"{angle:.4f}", f"{magnitude:.4f}", state, frame.resultant_force]
                    row += list(frame.matrix)
                    self._csv_writer.writerow(row)
                    if self.frame_count % 10 == 0:
                        self._csv_file.flush()

                yield sensor_stream_pb2.PztAngleResponse(
                    seq=frame.seq,
                    timestamp_ms=frame.timestamp_ms,
                    angle=angle,
                    dts_ms=frame.dts_ms,
                    ok=ok,
                    message=message,
                )

                # Throughput report every 100 frames.
                if self.frame_count % 100 == 0:
                    now = time.time()
                    elapsed = now - self.last_report_time
                    fps = 100 / elapsed if elapsed > 0 else 0
                    self.last_report_time = now
                    angle_text = (
                        f"{self.last_angle:.2f}"
                        if self.last_angle is not None
                        else "n/a"
                    )
                    print(
                        f"[SensorPush] Frame #{frame.seq} | "
                        f"{frame.rows}x{frame.cols} | "
                        f"angle={angle_text} | "
                        f"force={frame.resultant_force:.1f} | "
                        f"total={self.frame_count} | ~{fps:.1f} fps"
                    )
        finally:
            # Runs on normal end of stream, on errors raised while iterating,
            # and when the generator is closed early, so the CSV handle is
            # always released.
            self._close_csv()
            print(f"[SensorPush] Stream ended. Total: {self.frame_count}")
class ExportProcessorServicer(sensor_stream_pb2_grpc.ExportProcessorServicer):
    """Unary service that post-processes an exported CSV file."""

    def ProcessFile(self, request, context):
        """Run ``process_csv`` for the request and wrap the outcome in a ``ProcessResponse``.

        Any exception is reported to the client as ``ok=False`` with the
        exception text in ``message`` rather than propagated.
        """
        source, as_xlsx = request.csv_path, request.save_as_xlsx
        print(f"[ExportProcessor] Processing: {source} (xlsx={as_xlsx})")
        response_fields = (
            "ok",
            "output_path",
            "groups_used",
            "mean_value",
            "threshold",
            "rows_total",
            "rows_kept",
            "message",
        )
        try:
            outcome = process_csv(source, as_xlsx)
            payload = {field: outcome[field] for field in response_fields}
            return sensor_stream_pb2.ProcessResponse(**payload)
        except Exception as exc:
            print(f"[ExportProcessor] Error: {exc}")
            failure = {"ok": False, "output_path": "", "message": str(exc)}
            return sensor_stream_pb2.ProcessResponse(**failure)
# ── 启动 ────────────────────────────────────────────────────────
def serve(port: int):
    """Start the gRPC server on ``port`` and block until it terminates.

    Registers the SensorPush and ExportProcessor servicers on an insecure
    port bound to all interfaces, and installs SIGINT/SIGTERM handlers that
    stop the server with a 5-second grace period before exiting.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    sensor_stream_pb2_grpc.add_SensorPushServicer_to_server(SensorPushServicer(), server)
    sensor_stream_pb2_grpc.add_ExportProcessorServicer_to_server(ExportProcessorServicer(), server)
    listen_addr = f"0.0.0.0:{port}"
    server.add_insecure_port(listen_addr)
    server.start()
    print(f"[DevKit Server] gRPC listening on {listen_addr}")
    print("[DevKit Server] Services: SensorPush (streaming), ExportProcessor (unary)")

    def shutdown(signum, frame):
        """Signal handler: stop the server gracefully, then exit the process."""
        print("\n[DevKit Server] Shutting down...")
        # stop() returns immediately with an event; wait on it so in-flight
        # RPCs actually get the grace period before the process exits.
        server.stop(grace=5).wait()
        sys.exit(0)

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)
    server.wait_for_termination()
import numpy as np
import threading
from collections import deque
# ===================== Algorithm parameters =====================
COP_INIT_MEDIAN_FRAMES = 15  # number of frames whose median gives the initial CoP
NOISE_COLLECT_FRAMES = 20  # number of frames collected to derive the dynamic threshold
THRESH_K = 5  # threshold = mean + K * std
SENSOR_ROWS = 12
SENSOR_COLS = 7
# ===================== Post-init (second, at-rest) refinement parameters =====================
# Refinement is attempted only during the first POST_INIT_WINDOW_CNT frames
# after initialization; it succeeds once POST_INIT_STABLE_CNT consecutive CoP
# samples stay within POST_INIT_STABLE_THRESH of the candidate point.
POST_INIT_WINDOW_CNT = 60000
POST_INIT_STABLE_CNT = 200
POST_INIT_STABLE_THRESH = 0.1
# ===================== Module-level global state =====================
# Only first_frame is guarded by first_frame_lock; the remaining globals are
# read and written without a lock.
first_frame = None
first_frame_lock = threading.Lock()
first_contact_CoP_x = None
first_contact_CoP_y = None
contact_initialized = False
# Buffers of candidate initial CoP coordinates
cop_init_x_buf = deque(maxlen=COP_INIT_MEDIAN_FRAMES)
cop_init_y_buf = deque(maxlen=COP_INIT_MEDIAN_FRAMES)
# Dynamic threshold
noise_sum_buf = deque(maxlen=NOISE_COLLECT_FRAMES)
dynamic_thresh = None
# Post-init refinement state
post_init_frame_cnt = 0
post_stable_cnt = 0
post_refined_flag = False
post_cand_x = None
post_cand_y = None
# ===================== 基线减除 =====================
def subtract_baseline(current_frame):
    """Return the frame minus the stored baseline, clipped at zero.

    The first frame seen after a reset becomes the baseline. The input is
    converted to a flat float32 array before the subtraction.
    """
    global first_frame
    sample = np.array(current_frame, dtype=np.float32).flatten()
    with first_frame_lock:
        if first_frame is None:
            # No baseline yet: this frame becomes the reference.
            first_frame = sample.copy()
        delta = sample - first_frame
    return np.clip(delta, 0, None)
# ===================== 重置CoP状态 =====================
def reset_cop_state():
    """Reset the CoP reference point, its candidate buffers and the refinement state."""
    global first_contact_CoP_x, first_contact_CoP_y, contact_initialized
    global post_init_frame_cnt, post_stable_cnt, post_refined_flag
    global post_cand_x, post_cand_y

    # Initial-contact reference point.
    first_contact_CoP_x = first_contact_CoP_y = None
    contact_initialized = False
    for buf in (cop_init_x_buf, cop_init_y_buf):
        buf.clear()

    # Post-init refinement bookkeeping.
    post_init_frame_cnt = post_stable_cnt = 0
    post_refined_flag = False
    post_cand_x = post_cand_y = None
# ===================== CoP压力中心计算 =====================
def compute_pressure_direction(baseline_subtracted_frame):
    """Compute the centre of pressure (CoP) and its offset from the reference point.

    The frame is reshaped to SENSOR_ROWS x SENSOR_COLS. The function updates
    module-level state across calls: the dynamic noise threshold, the
    initial-contact reference point and the post-init refinement counters.

    Returns:
        A 12-tuple ``(cop_x, cop_y, 0, rows-1, 0, cols-1, delta_x, delta_y,
        base_x, base_y, magnitude, state)``. ``state`` is 0 while the
        reference point is not initialized, 1 once it is initialized but not
        yet refined, and 2 after refinement has finished or its window has
        expired. When the total pressure is below the dynamic threshold or
        exactly zero, all CoP-related entries are 0.0 and ``state`` is 0.
    """
    global first_contact_CoP_x, first_contact_CoP_y, contact_initialized
    global post_init_frame_cnt, post_stable_cnt, post_refined_flag
    global post_cand_x, post_cand_y
    global noise_sum_buf, dynamic_thresh
    rows, cols = SENSOR_ROWS, SENSOR_COLS
    frame_flat = np.asarray(baseline_subtracted_frame, dtype=np.float32).flatten()
    frame2d = frame_flat.reshape(rows, cols)
    total_pressure = np.sum(frame2d)
    # Dynamic threshold: collect total-pressure samples until enough frames
    # are available, then fix the threshold at mean + THRESH_K * std.
    if dynamic_thresh is None:
        noise_sum_buf.append(total_pressure)
        if len(noise_sum_buf) >= NOISE_COLLECT_FRAMES:
            sums = np.array(noise_sum_buf)
            dynamic_thresh = float(np.mean(sums) + THRESH_K * np.std(sums))
    # Low-pressure reset: below the threshold the contact is considered
    # released, so any established reference point is discarded.
    if dynamic_thresh is not None and total_pressure < dynamic_thresh:
        if contact_initialized:
            reset_cop_state()
        return 0.0, 0.0, 0, rows-1, 0, cols-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0
    # Guard against division by zero in the CoP computation below.
    if total_pressure == 0:
        return 0.0, 0.0, 0, rows-1, 0, cols-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0
    # Pressure-weighted mean of the column (x) and row (y) indices.
    x_grid = np.tile(np.arange(cols), (rows, 1))
    y_grid = np.repeat(np.arange(rows), cols).reshape(rows, cols)
    cop_x = np.sum(frame2d * x_grid) / total_pressure
    cop_y = np.sum(frame2d * y_grid) / total_pressure
    delta_CoP_x = 0.0
    delta_CoP_y = 0.0
    base_x = cop_x
    base_y = cop_y
    # ============ Initial reference point (median of the first samples) ============
    if not contact_initialized:
        cop_init_x_buf.append(cop_x)
        cop_init_y_buf.append(cop_y)
        if len(cop_init_x_buf) >= COP_INIT_MEDIAN_FRAMES:
            xs = list(cop_init_x_buf)
            ys = list(cop_init_y_buf)
            first_contact_CoP_x = float(np.median(xs))
            first_contact_CoP_y = float(np.median(ys))
            print(f"[CoP Init] 前{COP_INIT_MEDIAN_FRAMES}帧坐标:")
            for i in range(len(xs)):
                print(f" frame {i}: x={xs[i]:.3f}, y={ys[i]:.3f}")
            print(f" 中位数: x={first_contact_CoP_x:.3f}, y={first_contact_CoP_y:.3f}")
            contact_initialized = True
            cop_init_x_buf.clear()
            cop_init_y_buf.clear()
    # ========== Offset from the reference point ==========
    else:
        # Post-init refinement: while inside the window, look for a run of
        # POST_INIT_STABLE_CNT samples that stay within
        # POST_INIT_STABLE_THRESH of a candidate point, and adopt that
        # candidate as the new reference.
        post_init_frame_cnt += 1
        if not post_refined_flag and post_init_frame_cnt <= POST_INIT_WINDOW_CNT:
            if post_cand_x is not None:
                dist_val = np.hypot(cop_x - post_cand_x, cop_y - post_cand_y)
                if dist_val <= POST_INIT_STABLE_THRESH:
                    post_stable_cnt += 1
                else:
                    # Moved away from the candidate: restart with the current point.
                    post_cand_x = cop_x
                    post_cand_y = cop_y
                    post_stable_cnt = 1
            else:
                post_cand_x = cop_x
                post_cand_y = cop_y
                post_stable_cnt = 1
            if post_stable_cnt >= POST_INIT_STABLE_CNT:
                first_contact_CoP_x = post_cand_x
                first_contact_CoP_y = post_cand_y
                post_refined_flag = True
        else:
            # Already refined, or the window expired: stop refining.
            post_refined_flag = True
        # x is current minus reference; y is reference minus current.
        delta_CoP_x = cop_x - first_contact_CoP_x
        delta_CoP_y = first_contact_CoP_y - cop_y
        base_x = first_contact_CoP_x
        base_y = first_contact_CoP_y
    magnitude = np.hypot(delta_CoP_x, delta_CoP_y)
    if not contact_initialized:
        state = 0
    elif not post_refined_flag:
        state = 1
    else:
        state = 2
    return (cop_x, cop_y,
            0, rows-1, 0, cols-1,
            delta_CoP_x, delta_CoP_y,
            base_x, base_y,
            magnitude, state)
# ===================== 角度计算核心 =====================
def compute_vector_angle(x: float, y: float) -> tuple[float, float]:
    """Return the direction of the vector ``(x, y)`` in degrees and its length.

    Args:
        x: Horizontal component.
        y: Vertical component.

    Returns:
        ``(angle, magnitude)``. Negative ``arctan2`` results are shifted up
        by 360 so the angle is non-negative. The zero vector gives
        ``(0.0, 0.0)``.
    """
    mag = np.hypot(x, y)
    # arctan2 is well defined for x == 0, so no epsilon is added to x: an
    # offset would bias every angle and flips the result for tiny negative x.
    angle = np.degrees(np.arctan2(y, x))
    if angle < 0:
        angle += 360
    return angle, mag
def compute_PZT_angle(Px: float, Py: float) -> tuple[float, float]:
    """Return ``(angle, magnitude)`` for the vector ``(Px, Py)`` via ``compute_vector_angle``."""
    angle_and_magnitude = compute_vector_angle(Px, Py)
    return angle_and_magnitude
# ===================== 核心入口函数 =====================
def get_pzt_angle(adc_data):
    """Compute the PZT angle, offset magnitude and tracking state for one ADC frame.

    Args:
        adc_data: Flat sequence of ``SENSOR_ROWS * SENSOR_COLS`` readings.

    Returns:
        ``(pzt_angle, magnitude, state)``: the angle in degrees of the CoP
        offset vector, the magnitude and the integer state, the latter two
        taken from ``compute_pressure_direction``.

    Raises:
        ValueError: if ``adc_data`` does not have the expected length.
    """
    # Derive the expected length from the sensor geometry instead of a
    # hard-coded 84 so it stays consistent with the rest of the module.
    expected_len = SENSOR_ROWS * SENSOR_COLS
    if len(adc_data) != expected_len:
        raise ValueError(f"ADC数据长度必须为{expected_len}")
    baseline_subtracted = subtract_baseline(adc_data)
    result = compute_pressure_direction(baseline_subtracted)
    # Positions in the result tuple: 6/7 = delta x/y, 10 = magnitude, 11 = state.
    dx, dy = result[6], result[7]
    magnitude = result[10]
    state = int(result[11])
    pzt_angle, _ = compute_PZT_angle(dx, dy)
    return pzt_angle, magnitude, state
# ===================== 重置基线(校准用) =====================
def reset_baseline():
    """Forget the baseline frame, the dynamic threshold and all CoP state (used for calibration)."""
    global first_frame, dynamic_thresh
    with first_frame_lock:
        first_frame = None
    # Drop the collected noise samples and the threshold derived from them.
    noise_sum_buf.clear()
    dynamic_thresh = None
    reset_cop_state()
if __name__ == "__main__":
    # Command-line entry point: parse the listen port and start the server.
    cli = argparse.ArgumentParser(description="JE-Skin DevKit gRPC Server")
    cli.add_argument(
        "--port",
        type=int,
        default=50051,
        help="gRPC listen port (default: 50051)",
    )
    serve(cli.parse_args().port)