Files
JE-Skin/devkit/sensor_server.py
2026-06-02 09:43:05 +08:00

978 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
JE-Skin DevKit — Python gRPC Sensor Server
提供两个服务:
1. SensorPush (streaming) — 接收实时传感器帧
2. ExportProcessor (unary) — 处理导出的 CSV 文件（梯度过滤、xlsx 转换）
安装依赖:
pip install grpcio grpcio-tools openpyxl numpy
生成 gRPC 代码:
python -m grpc_tools.protoc -I../src-tauri/proto --python_out=. --grpc_python_out=. ../src-tauri/proto/sensor_stream.proto
启动:
python sensor_server.py [--port 50051]
"""
from __future__ import annotations

import argparse
import csv
import math
import os
import signal
import statistics
import sys
import time
from concurrent import futures
from pathlib import Path

import grpc

import sensor_stream_pb2
import sensor_stream_pb2_grpc
# ── 梯度过滤逻辑(来自用户的 main.py ─────────────────────────
def load_rows(path: Path) -> list[list[str]]:
    """Read *path* as CSV and return its rows, skipping empty ones.

    The file is decoded as UTF-8; a leading BOM (as written by Excel) is
    tolerated via the ``utf-8-sig`` codec.
    """
    with path.open("r", encoding="utf-8-sig", newline="") as handle:
        # csv.reader yields [] for blank lines; filter(None, ...) drops them.
        return list(filter(None, csv.reader(handle)))
def row_sum(row: list[str]) -> float:
    """Sum the numeric cells of *row*, ignoring the first column and blank cells.

    Raises:
        ValueError: if a non-blank cell is not a valid float.
    """
    total = 0
    for cell in row[1:]:
        if not cell.strip():
            continue
        total += float(cell)
    return total
def find_threshold(sum_values: list[float]) -> float:
    """Return the midpoint of the widest gap between consecutive sorted values.

    This splits the row sums into a "released" and a "pressed" population.
    When several gaps tie, the lowest one wins.

    Raises:
        ValueError: if fewer than two values are given.
    """
    if len(sum_values) < 2:
        raise ValueError("At least two rows are required.")
    ordered = sorted(sum_values)
    lower, upper = max(
        zip(ordered, ordered[1:]),
        key=lambda pair: pair[1] - pair[0],
    )
    return (lower + upper) / 2.0
def extract_press_groups(
    rows: list[list[str]], sum_values: list[float], threshold: float
) -> tuple[list[list[str]], list[float]]:
    """Keep rows whose sum reaches *threshold* and average each consecutive run.

    Args:
        rows: CSV rows, aligned with *sum_values*.
        sum_values: Per-row sums (see ``row_sum``).
        threshold: Minimum sum for a row to count as "pressed".

    Returns:
        ``(kept_rows, run_means)``: the kept rows in input order, and the mean
        sum of every maximal run of consecutive kept rows.
    """
    kept_rows: list[list[str]] = []
    run_means: list[float] = []
    run: list[float] = []
    for row, total in zip(rows, sum_values):
        if total >= threshold:
            kept_rows.append(row)
            run.append(total)
        elif run:
            # A below-threshold row closes the current press run.
            run_means.append(statistics.fmean(run))
            run = []
    if run:
        run_means.append(statistics.fmean(run))
    return kept_rows, run_means
def write_csv(path: Path, rows: list[list[str]]) -> Path:
    """Write *rows* to ``<stem>_filtered.csv`` next to *path* and return that path.

    The output is UTF-8 with a BOM (``utf-8-sig``) so Excel detects the encoding.
    """
    out_path = path.with_name(path.stem + "_filtered.csv")
    with out_path.open("w", encoding="utf-8-sig", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerows(rows)
    return out_path
def _xlsx_cell_value(value: str):
    """Return *value* as a float if it parses as a finite number, else unchanged.

    ``float()`` decides what is numeric, so signs, decimals and exponents
    ("1e-3") are recognised. Strings such as dates ("2026-06-02") or
    "5-3" stay text instead of raising. NaN/inf strings are kept as text,
    as before.
    """
    try:
        number = float(value)
    except ValueError:
        return value
    return number if math.isfinite(number) else value


def write_xlsx(path: Path, rows: list[list[str]], stats: dict) -> Path:
    """Write the filtered rows plus a statistics sheet to ``<stem>_filtered.xlsx``.

    Args:
        path: Source CSV path; the workbook is written next to it.
        rows: Filtered CSV rows. Cells that parse as finite numbers are stored
            as numbers, everything else as text.
        stats: Summary values (keys as built by ``process_csv``) shown on the
            "Statistics" sheet. Missing keys fall back to empty/zero.

    Returns:
        Path of the written workbook.

    Raises:
        RuntimeError: If openpyxl is not installed.
    """
    try:
        import openpyxl
    except ImportError as err:
        raise RuntimeError(
            "openpyxl is required for xlsx output. Install it with: pip install openpyxl"
        ) from err
    from openpyxl.styles import Font, PatternFill

    wb = openpyxl.Workbook()

    # Sheet 1: the filtered data.
    ws_data = wb.active
    ws_data.title = "Filtered Data"
    for row in rows:
        ws_data.append([_xlsx_cell_value(c) for c in row])

    # Sheet 2: run statistics.
    ws_stats = wb.create_sheet("Statistics")
    header_font = Font(bold=True, size=11)
    header_fill = PatternFill(start_color="E0E0E0", end_color="E0E0E0", fill_type="solid")
    ws_stats.append(["Parameter", "Value"])
    for header_cell in (ws_stats["A1"], ws_stats["B1"]):
        header_cell.font = header_font
        header_cell.fill = header_fill
    stats_rows = [
        ("Source File", stats.get("source_file", "")),
        ("Total Rows", stats.get("rows_total", 0)),
        ("Filtered Rows", stats.get("rows_kept", 0)),
        ("Groups Used", stats.get("groups_used", 0)),
        ("Mean Value", f"{stats.get('mean_value', 0):.3f}"),
        ("Threshold", f"{stats.get('threshold', 0):.3f}"),
        ("Process Time", stats.get("process_time", "")),
    ]
    for label, value in stats_rows:
        ws_stats.append([label, value])
    ws_stats.column_dimensions["A"].width = 18
    ws_stats.column_dimensions["B"].width = 30

    out = path.with_name(f"{path.stem}_filtered.xlsx")
    wb.save(str(out))
    return out
def process_csv(csv_path: str, save_as_xlsx: bool) -> dict:
    """Gradient-filter an exported CSV and write the result.

    Row sums are split at their widest gap (``find_threshold``). Rows at or
    above the threshold are kept, and the mean of each consecutive press run
    is computed.

    With ``save_as_xlsx`` the result goes to ``<stem>_filtered.xlsx`` and
    the source CSV is deleted. Otherwise the filtered CSV replaces the source
    file in place. Either way a summary row is appended (best effort) to
    ``devkit_analysis_results.xlsx`` in the same directory.

    Args:
        csv_path: Path of the exported CSV.
        save_as_xlsx: Write an xlsx workbook instead of rewriting the CSV.

    Returns:
        Dict with ok, output_path, groups_used, mean_value, threshold,
        rows_total, rows_kept and message.

    Raises:
        FileNotFoundError: If *csv_path* is not a file.
        ValueError: If the CSV is empty, has fewer than two rows, contains a
            non-numeric data cell, or has no rows above the threshold.
        RuntimeError: If xlsx output is requested without openpyxl.
    """
    path = Path(csv_path)
    if not path.is_file():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")
    rows = load_rows(path)
    if not rows:
        raise ValueError("CSV file is empty.")
    sum_values = [row_sum(r) for r in rows]
    threshold = find_threshold(sum_values)
    filtered_rows, group_means = extract_press_groups(rows, sum_values, threshold)
    if not filtered_rows:
        raise ValueError("No large press-down data was found.")
    overall_mean = statistics.fmean(group_means)
    stats = {
        "source_file": path.name,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "process_time": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    if save_as_xlsx:
        output_path = write_xlsx(path, filtered_rows, stats)
        # The workbook supersedes the source CSV.
        try:
            path.unlink()
        except OSError as err:
            print(f"[ExportProcessor] Warning: could not delete source CSV {path}: {err}")
    else:
        filtered_path = write_csv(path, filtered_rows)
        # Path.replace overwrites the source in a single step (atomically on
        # POSIX, and also over an existing file on Windows). unlink()+rename()
        # deleted the source before knowing the rename would succeed.
        try:
            filtered_path.replace(path)
            output_path = path
        except OSError as err:
            print(f"[ExportProcessor] Warning: could not replace {path} with {filtered_path}: {err}")
            output_path = filtered_path
    stats["output_file"] = output_path.name
    # The summary log is auxiliary. By now the source has already been
    # replaced or deleted, so a log failure (e.g. workbook locked by Excel)
    # must not turn a successful run into a reported error.
    try:
        _append_analysis_log(csv_path, stats)
    except Exception as err:
        print(f"[ExportProcessor] Warning: could not update analysis log: {err}")
    return {
        "ok": True,
        "output_path": str(output_path),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "message": "OK",
    }
def _append_analysis_log(source_csv: str, stats: dict):
"""将处理结果追加到 devkit_analysis_results.xlsx"""
try:
import openpyxl
except ImportError:
return # openpyxl 不可用时跳过
log_path = Path(source_csv).parent / "devkit_analysis_results.xlsx"
if log_path.exists():
wb = openpyxl.load_workbook(str(log_path))
ws = wb.active
else:
wb = openpyxl.Workbook()
ws = wb.active
ws.title = "Analysis Log"
ws.append(["Time", "Source File", "Total Rows", "Kept Rows",
"Groups", "Mean Value", "Threshold", "Output File"])
ws.append([
stats.get("process_time", ""),
stats.get("source_file", ""),
stats.get("rows_total", 0),
stats.get("rows_kept", 0),
stats.get("groups_used", 0),
round(stats.get("mean_value", 0), 3),
round(stats.get("threshold", 0), 3),
f"{Path(stats.get('source_file', '')).stem}_filtered",
])
wb.save(str(log_path))
# ── gRPC 服务实现 ────────────────────────────────────────────────
class SensorPushServicer(sensor_stream_pb2_grpc.SensorPushServicer):
    """Receive real-time sensor frames and stream back one PZT angle result per frame."""

    def __init__(self):
        # Lifetime frame total across all streams; per-stream counting lives in Upload.
        self.frame_count = 0
        self.last_report_time = time.time()
        # Latest successful estimate, shown in the periodic progress line.
        self.last_angle = None
        self.last_state = 0
        self.last_magnitude = 0.0

    def Upload(self, request_iterator, context):
        """Consume a stream of sensor frames, yielding a PztAngleResponse for each.

        The estimator is reset when a client connects. A frame whose matrix
        length is not SENSOR_ROWS * SENSOR_COLS, or whose computation raises,
        is answered with ``ok=False`` and the reason in ``message``.
        """
        print("[SensorPush] Client connected, waiting for frames...")
        reset_baseline()
        self.last_angle = None
        self.last_state = 0
        self.last_magnitude = 0.0
        # One servicer instance serves every connection. The verbose-log
        # window, the 100-frame progress cadence and the fps estimate must
        # therefore be per stream. When they used the lifetime total, a
        # reconnecting client got no first-frame logs and a bogus fps.
        stream_frames = 0
        self.last_report_time = time.time()
        for frame in request_iterator:
            self.frame_count += 1
            stream_frames += 1
            angle = 0.0
            magnitude = 0.0
            state = 0
            cop_x = 0.0
            cop_y = 0.0
            base_x = 0.0
            base_y = 0.0
            total_press = 0.0
            threshold = 0.0
            ok = True
            message = "OK"
            if len(frame.matrix) == SENSOR_ROWS * SENSOR_COLS:
                try:
                    # NOTE(review): dts_ms drives the estimator's millisecond
                    # windows; confirm it is a monotonic timestamp rather than
                    # a per-frame delta.
                    (angle, magnitude, state, cop_x, cop_y,
                     base_x, base_y, total_press, threshold) = get_pzt_angle(
                        frame.matrix, float(frame.dts_ms)
                    )
                    # The threshold is None until the noise floor is calibrated.
                    threshold = threshold or 0.0
                    self.last_angle = angle
                    self.last_state = state
                    self.last_magnitude = magnitude
                    if stream_frames <= 10 or stream_frames % 30 == 0:
                        print(
                            f"[SensorPush] PZT angle frame #{frame.seq} "
                            f"dts={frame.dts_ms} angle={angle:.2f} "
                            f"mag={magnitude:.2f} state={state} "
                            f"cop=({cop_x:.2f},{cop_y:.2f}) "
                            f"base=({base_x:.2f},{base_y:.2f}) "
                            f"total={total_press:.2f} threshold={threshold:.2f}"
                        )
                except Exception as e:
                    ok = False
                    message = str(e)
                    print(f"[SensorPush] PZT compute error on frame #{frame.seq}: {e}")
            else:
                ok = False
                message = f"Invalid matrix length: {len(frame.matrix)}"
            yield sensor_stream_pb2.PztAngleResponse(
                seq=frame.seq,
                timestamp_ms=frame.timestamp_ms,
                angle=angle,
                dts_ms=frame.dts_ms,
                ok=ok,
                message=message,
                magnitude=magnitude,
                state=state,
                cop_x=cop_x,
                cop_y=cop_y,
                base_x=base_x,
                base_y=base_y,
                total_press=total_press,
                threshold=threshold,
            )
            if stream_frames % 100 == 0:
                now = time.time()
                elapsed = now - self.last_report_time
                fps = 100 / elapsed if elapsed > 0 else 0
                self.last_report_time = now
                angle_text = (
                    f"{self.last_angle:.2f}"
                    if self.last_angle is not None
                    else "n/a"
                )
                print(
                    f"[SensorPush] Frame #{frame.seq} | "
                    f"{frame.rows}x{frame.cols} | "
                    f"angle={angle_text} | "
                    f"mag={self.last_magnitude:.2f} | "
                    f"state={self.last_state} | "
                    f"force={frame.resultant_force:.1f} | "
                    f"total={stream_frames} | ~{fps:.1f} fps"
                )
        print(f"[SensorPush] Stream ended. Total: {stream_frames}")
class ExportProcessorServicer(sensor_stream_pb2_grpc.ExportProcessorServicer):
    """Post-process exported CSV files on request (unary RPC)."""

    def ProcessFile(self, request, context):
        """Run ``process_csv`` on ``request.csv_path`` and report the outcome.

        Any ``Exception`` (including one raised while building the response)
        is turned into a ProcessResponse with ``ok=False`` and the error text
        in ``message``.
        """
        csv_path, save_as_xlsx = request.csv_path, request.save_as_xlsx
        print(f"[ExportProcessor] Processing: {csv_path} (xlsx={save_as_xlsx})")
        response_keys = (
            "ok", "output_path", "groups_used", "mean_value",
            "threshold", "rows_total", "rows_kept", "message",
        )
        try:
            result = process_csv(csv_path, save_as_xlsx)
            fields = {key: result[key] for key in response_keys}
            return sensor_stream_pb2.ProcessResponse(**fields)
        except Exception as e:
            print(f"[ExportProcessor] Error: {e}")
            return sensor_stream_pb2.ProcessResponse(ok=False, output_path="", message=str(e))
# ── 启动 ────────────────────────────────────────────────────────
def serve(port: int, host: str = "0.0.0.0"):
    """Run the DevKit gRPC server until SIGINT/SIGTERM.

    Args:
        port: TCP port to listen on (0 lets gRPC choose a free port).
        host: Interface to bind; defaults to all interfaces.

    Raises:
        RuntimeError: If the listen address cannot be bound.

    NOTE(review): the port is insecure (no TLS/auth), and ExportProcessor
    deletes or overwrites files at client-supplied paths. Prefer
    ``host="127.0.0.1"`` unless remote clients are required.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    sensor_stream_pb2_grpc.add_SensorPushServicer_to_server(SensorPushServicer(), server)
    sensor_stream_pb2_grpc.add_ExportProcessorServicer_to_server(ExportProcessorServicer(), server)
    listen_addr = f"{host}:{port}"
    bound_port = server.add_insecure_port(listen_addr)
    if not bound_port:
        # Some grpcio versions report a bind failure by returning 0 instead of raising.
        raise RuntimeError(f"Failed to bind gRPC server to {listen_addr}")
    server.start()
    print(f"[DevKit Server] gRPC listening on {host}:{bound_port}")
    print("[DevKit Server] Services: SensorPush (streaming), ExportProcessor (unary)")

    def shutdown(signum, frame):
        print("\n[DevKit Server] Shutting down...")
        # stop() returns immediately with a threading.Event. Waiting on it
        # gives in-flight RPCs the 5 s grace period before the process exits.
        server.stop(grace=5).wait()
        sys.exit(0)

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)
    server.wait_for_termination()
import numpy as np
import threading
from collections import deque
# ===================== Legacy tangential-force algorithm parameters =====================
# Sensor matrix geometry (also used by the current estimator and the gRPC servicer).
SENSOR_ROWS = 12
SENSOR_COLS = 7
# Number of contact frames whose CoP median sets the legacy first-contact reference.
COP_INIT_MEDIAN_FRAMES = 1
# Frames of total-pressure samples collected before the legacy threshold is fixed.
NOISE_COLLECT_FRAMES = 20
# Legacy contact threshold = THRESH_K * mean of the collected noise sums.
THRESH_K = 5
# The reference is snapped to this point when it lies within the ranges below.
# A range of 0.0 snaps only on an exact hit.
SNAP_CENTER_X = 3.0
SNAP_CENTER_Y = 5.5
SNAP_RANGE_X = 0.0
SNAP_RANGE_Y = 0.0
# Post-init refinement parameters:
#   POST_INIT_WINDOW_CNT    frames during which refinement may happen
#   POST_INIT_STABLE_CNT    consecutive stable frames required
#   POST_INIT_STABLE_THRESH max CoP movement still counted as stable
POST_INIT_WINDOW_CNT = 600000
POST_INIT_STABLE_CNT = 500
POST_INIT_STABLE_THRESH = 0.1

# ===================== Legacy module-level state =====================
# Only first_frame (the baseline) is guarded by a lock; the other globals are unsynchronised.
first_frame = None
first_frame_lock = threading.Lock()
# First-contact reference CoP and its initialisation buffers.
first_contact_CoP_x = None
first_contact_CoP_y = None
contact_initialized = False
cop_init_x_buf = deque(maxlen=COP_INIT_MEDIAN_FRAMES)
cop_init_y_buf = deque(maxlen=COP_INIT_MEDIAN_FRAMES)
# Noise calibration.
noise_sum_buf = deque(maxlen=NOISE_COLLECT_FRAMES)
dynamic_thresh = None
# Post-init refinement progress.
post_init_frame_cnt = 0
post_stable_cnt = 0
post_refined_flag = False
post_cand_x = None
post_cand_y = None
# ===================== 基线减除 =====================
def subtract_baseline(current_frame):
    """Return ``current_frame - baseline`` as flat float32, clipped at 0.

    The first frame seen after a reset (``first_frame is None``) becomes the
    baseline, so that frame itself yields all zeros. The baseline is read and
    written under ``first_frame_lock``.
    """
    global first_frame
    flat = np.array(current_frame, dtype=np.float32).reshape(-1)
    with first_frame_lock:
        if first_frame is None:
            first_frame = flat.copy()
        delta = flat - first_frame
    return np.clip(delta, 0, None)
# ===================== 重置CoP状态 =====================
def _legacy_reset_cop_state():
    """Forget the legacy first-contact reference and any refinement progress.

    The noise calibration (noise_sum_buf / dynamic_thresh) is left untouched.
    """
    global first_contact_CoP_x, first_contact_CoP_y, contact_initialized
    global post_init_frame_cnt, post_stable_cnt, post_refined_flag
    global post_cand_x, post_cand_y
    first_contact_CoP_x = first_contact_CoP_y = None
    post_cand_x = post_cand_y = None
    contact_initialized = post_refined_flag = False
    post_init_frame_cnt = post_stable_cnt = 0
    for init_buf in (cop_init_x_buf, cop_init_y_buf):
        init_buf.clear()
# ===================== CoP压力中心计算新算法 =====================
def _legacy_compute_pressure_direction(baseline_subtracted_frame):
    """Legacy, frame-count-based CoP direction estimate.

    Superseded by PressureDirectionEstimator; the gRPC path does not call it.
    It works entirely on module-level state. The frame is flattened to
    float32 and reshaped to (SENSOR_ROWS, SENSOR_COLS). The parameter name
    suggests the frame should already be baseline-subtracted (e.g. by
    subtract_baseline); this function does not subtract anything itself.

    Returns a 14-tuple:
        (cop_x, cop_y, row_min, row_max, col_min, col_max,
         dx, dy, base_x, base_y, magnitude, state,
         total_pressure, dynamic_thresh)

    ``state`` is 0 before the first-contact reference exists, 1 while the
    reference may still be refined, and 2 once refinement has finished.

    On no contact, every slot except row_max, col_max and dynamic_thresh is
    zero. That includes the total_pressure slot, which then does not carry
    the measured total.
    """
    global first_contact_CoP_x, first_contact_CoP_y, contact_initialized
    global post_init_frame_cnt, post_stable_cnt, post_refined_flag
    global post_cand_x, post_cand_y
    global noise_sum_buf, dynamic_thresh
    rows, cols = SENSOR_ROWS, SENSOR_COLS
    frame_flat = np.asarray(baseline_subtracted_frame, dtype=np.float32).flatten()
    frame2d = frame_flat.reshape(rows, cols)
    total_pressure = np.sum(frame2d)
    # Noise calibration: the threshold becomes THRESH_K * mean of the first
    # NOISE_COLLECT_FRAMES frame totals.
    if dynamic_thresh is None:
        noise_sum_buf.append(total_pressure)
        if len(noise_sum_buf) >= NOISE_COLLECT_FRAMES:
            sums = np.array(noise_sum_buf)
            dynamic_thresh = THRESH_K * float(np.mean(sums))
    # No contact. While the threshold is still None, only an all-zero frame
    # counts as no contact; any other frame is processed as contact.
    if total_pressure == 0 or (dynamic_thresh is not None and total_pressure < dynamic_thresh):
        if contact_initialized and dynamic_thresh is not None:
            _legacy_reset_cop_state()
        return 0.0, 0.0, 0, rows-1, 0, cols-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0.0, dynamic_thresh
    # CoP = pressure-weighted mean column (x) and row (y) index.
    x_grid = np.tile(np.arange(cols), (rows, 1))
    y_grid = np.repeat(np.arange(rows), cols).reshape(rows, cols)
    cop_x = np.sum(frame2d * x_grid) / total_pressure
    cop_y = np.sum(frame2d * y_grid) / total_pressure
    delta_CoP_x = 0.0
    delta_CoP_y = 0.0
    base_x = cop_x
    base_y = cop_y
    if not contact_initialized:
        # First contact: the median of COP_INIT_MEDIAN_FRAMES CoPs becomes the
        # reference, optionally snapped to (SNAP_CENTER_X, SNAP_CENTER_Y).
        cop_init_x_buf.append(cop_x)
        cop_init_y_buf.append(cop_y)
        if len(cop_init_x_buf) >= COP_INIT_MEDIAN_FRAMES:
            first_contact_CoP_x = float(np.median(cop_init_x_buf))
            first_contact_CoP_y = float(np.median(cop_init_y_buf))
            contact_initialized = True
            cop_init_x_buf.clear()
            cop_init_y_buf.clear()
            if (abs(first_contact_CoP_x - SNAP_CENTER_X) <= SNAP_RANGE_X and
                    abs(first_contact_CoP_y - SNAP_CENTER_Y) <= SNAP_RANGE_Y):
                first_contact_CoP_x = SNAP_CENTER_X
                first_contact_CoP_y = SNAP_CENTER_Y
    else:
        # Reference exists. Within the first POST_INIT_WINDOW_CNT frames it is
        # replaced by a candidate CoP once the CoP has stayed within
        # POST_INIT_STABLE_THRESH of that candidate for POST_INIT_STABLE_CNT
        # consecutive frames.
        post_init_frame_cnt += 1
        if not post_refined_flag and post_init_frame_cnt <= POST_INIT_WINDOW_CNT:
            if post_cand_x is not None:
                dist_val = np.hypot(cop_x - post_cand_x, cop_y - post_cand_y)
                if dist_val <= POST_INIT_STABLE_THRESH:
                    post_stable_cnt += 1
                else:
                    # CoP moved: restart stability tracking from this frame.
                    post_cand_x = cop_x
                    post_cand_y = cop_y
                    post_stable_cnt = 1
            else:
                post_cand_x = cop_x
                post_cand_y = cop_y
                post_stable_cnt = 1
            if post_stable_cnt >= POST_INIT_STABLE_CNT:
                first_contact_CoP_x = post_cand_x
                first_contact_CoP_y = post_cand_y
                post_refined_flag = True
        else:
            # Already refined, or the refinement window is exhausted.
            post_refined_flag = True
        # y is flipped so that a move toward row 0 gives a positive delta.
        delta_CoP_x = cop_x - first_contact_CoP_x
        delta_CoP_y = first_contact_CoP_y - cop_y
        base_x = first_contact_CoP_x
        base_y = first_contact_CoP_y
    magnitude = np.hypot(delta_CoP_x, delta_CoP_y)
    if not contact_initialized:
        state = 0
    elif not post_refined_flag:
        state = 1
    else:
        state = 2
    return (cop_x, cop_y,
            0, rows-1, 0, cols-1,
            delta_CoP_x, delta_CoP_y,
            base_x, base_y,
            magnitude, state,
            total_pressure, dynamic_thresh)
# ===================== 角度计算核心 =====================
def _legacy_compute_vector_angle(x: float, y: float) -> tuple[float, float]:
    """Return ``(angle_deg, magnitude)`` of the vector (x, y).

    The angle comes from ``arctan2(y, x + 1e-8)`` in degrees. Negative results
    are shifted by +360 so the angle is never negative.
    """
    magnitude = np.hypot(x, y)
    raw_deg = np.degrees(np.arctan2(y, x + 1e-8))
    return (raw_deg + 360 if raw_deg < 0 else raw_deg), magnitude


def _legacy_compute_PZT_angle(Px: float, Py: float) -> tuple[float, float]:
    """Alias of ``_legacy_compute_vector_angle`` for a (Px, Py) shift vector."""
    angle_deg, magnitude = _legacy_compute_vector_angle(Px, Py)
    return angle_deg, magnitude
# ===================== 核心入口函数 =====================
def _legacy_get_pzt_angle(adc_data):
    """Legacy entry point: run the legacy CoP estimate on one 84-value frame.

    Returns:
        (pzt_angle, magnitude, state, cop_x, cop_y, base_x, base_y,
         total_press, threshold)

    Raises:
        ValueError: if ``adc_data`` does not have 84 elements.
    """
    if len(adc_data) != 84:
        raise ValueError("ADC数据长度必须为84")
    (cop_x, cop_y, _row_min, _row_max, _col_min, _col_max,
     dx, dy, base_x, base_y, magnitude, raw_state,
     total_press, threshold) = _legacy_compute_pressure_direction(adc_data)
    state = int(raw_state)
    pzt_angle = _legacy_compute_PZT_angle(dx, dy)[0]
    return pzt_angle, magnitude, state, cop_x, cop_y, base_x, base_y, total_press, threshold
# ===================== 重置基线(校准用) =====================
def _legacy_reset_baseline():
    """Drop the legacy baseline and noise calibration, then reset the CoP state."""
    global first_frame, noise_sum_buf, dynamic_thresh
    with first_frame_lock:
        first_frame = None
    dynamic_thresh = None
    noise_sum_buf.clear()
    _legacy_reset_cop_state()
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Tuple
# Number of values in one flattened ADC frame.
ADC_LEN = SENSOR_COLS * SENSOR_ROWS
class CoPState(IntEnum):
    """Phase of PressureDirectionEstimator's per-contact state machine.

    The integer value is what the estimator reports as ``state``.
    """

    NO_CONTACT = 0  # no debounced contact; per-contact state has been cleared
    INIT_COLLECTING = 1  # collecting CoP samples for the first-contact reference
    POST_REFINING = 2  # reference set; waiting for the CoP to settle to refine it
    READY = 3  # reference fixed for the rest of the contact
@dataclass
class CoPResult:
    """One output of PressureDirectionEstimator.update.

    Coordinates are in sensor-cell index units: x is the pressure-weighted
    column index and y the pressure-weighted row index. On no contact all
    coordinate, offset, magnitude and angle fields are 0.0.
    """

    # Low-pass-filtered centre of pressure.
    cop_x: float
    cop_y: float
    # Cell window the CoP is computed over; the estimator always reports the full frame.
    row_min: int
    row_max: int
    col_min: int
    col_max: int
    # Offset from the reference: dx = cop_x - base_x, dy = base_y - cop_y.
    dx: float
    dy: float
    # Reference (first-contact) CoP; equals the current CoP until a reference exists.
    base_x: float
    base_y: float
    # Length of (dx, dy).
    magnitude: float
    # int value of CoPState.
    state: int
    # Sum of the frame after negatives are clipped to 0.
    total_pressure: float
    # Contact threshold; None until the noise floor has been calibrated.
    threshold: Optional[float]
    # Direction of (dx, dy) in degrees, shifted to be non-negative.
    angle: float
@dataclass
class CoPConfig:
    """Tuning parameters for PressureDirectionEstimator.

    Every ``*_ms`` value is compared against differences of the timestamps
    passed to ``PressureDirectionEstimator.update``.
    """

    # Sensor matrix shape; each ADC frame must contain rows * cols values.
    rows: int = SENSOR_ROWS
    cols: int = SENSOR_COLS
    # Noise calibration: total pressure is sampled for noise_collect_ms, then
    # the contact threshold is fixed at max(mean + thresh_k * std, min_threshold).
    noise_collect_ms: float = 300.0
    thresh_k: float = 5.0
    min_threshold: float = 50.0
    # Debounce: raw contact must persist this long before a contact starts...
    contact_confirm_ms: float = 20.0
    # ...and must be absent this long before an active contact ends.
    release_confirm_ms: float = 50.0
    # CoP samples are collected for this long; their median becomes the reference.
    init_collect_ms: float = 80.0
    # Snap the reference to (snap_center_x, snap_center_y) when it lies within
    # +/-snap_range_x and +/-snap_range_y of that point.
    snap_enable: bool = True
    snap_center_x: float = 3.0
    snap_center_y: float = 5.5
    snap_range_x: float = 0.25
    snap_range_y: float = 0.25
    # Post-init refinement. Within post_refine_window_ms, a candidate CoP that
    # the CoP stays within post_stable_thresh of for post_stable_ms replaces
    # the reference.
    post_refine_enable: bool = True
    post_refine_window_ms: float = 800.0
    post_stable_ms: float = 200.0
    post_stable_thresh: float = 0.1
    # Weight of the newest CoP sample in the low-pass filter (<= 0 disables filtering).
    cop_lpf_alpha: float = 0.25
    # Added to the x component before arctan2 in compute_vector_angle.
    epsilon: float = 1e-8
class PressureDirectionEstimator:
    """Estimate the direction in which the centre of pressure moves after contact.

    Each :meth:`update` call consumes one ADC frame plus a millisecond
    timestamp and runs, in order:

    1. noise-floor calibration;
    2. contact debounce;
    3. centre-of-pressure (CoP) computation with a first-order low-pass filter;
    4. a state machine that fixes a reference CoP at first contact,
       optionally refining it once the CoP settles.

    The reported vector is the filtered CoP's offset from that reference,
    with ``dy`` signed so that a move toward row 0 is positive. All
    configured windows are measured as differences of the caller's
    timestamps.
    """

    def __init__(self, config: Optional[CoPConfig] = None):
        """Create an estimator; ``config`` defaults to a fresh ``CoPConfig()``."""
        # A ``CoPConfig()`` default argument would be built once, at class
        # definition time, and shared (mutably) by every estimator created
        # without an explicit config.
        self.cfg = config if config is not None else CoPConfig()
        self.reset_all()

    def reset_all(self):
        """Forget the noise calibration and all per-contact state."""
        self.dynamic_thresh: Optional[float] = None
        self.noise_samples = []
        self.noise_start_ms: Optional[float] = None
        self.reset_contact_state()

    def reset_contact_state(self):
        """Return to NO_CONTACT and clear the reference, buffers, timers and filter."""
        self.first_contact_cop_x: Optional[float] = None
        self.first_contact_cop_y: Optional[float] = None
        self.state = CoPState.NO_CONTACT
        self.init_x_buf = []
        self.init_y_buf = []
        self.init_start_ms: Optional[float] = None
        self.post_start_ms: Optional[float] = None
        self.post_stable_start_ms: Optional[float] = None
        self.post_cand_x: Optional[float] = None
        self.post_cand_y: Optional[float] = None
        self.post_refined = False
        self.contact_candidate_start_ms: Optional[float] = None
        self.release_candidate_start_ms: Optional[float] = None
        self.filtered_cop_x: Optional[float] = None
        self.filtered_cop_y: Optional[float] = None
        # Last CoP computed during this contact; held for frames without pressure.
        self.last_cop_x: Optional[float] = None
        self.last_cop_y: Optional[float] = None

    def update(self, adc_data, timestamp_ms: float) -> CoPResult:
        """Process one frame and return the current CoP / direction estimate.

        Args:
            adc_data: Frame values; flattened, they must number rows * cols.
                Negative values are clipped to 0.
            timestamp_ms: Frame time in milliseconds.

        Returns:
            A CoPResult. While there is no debounced contact it is an
            all-zero result carrying only total_pressure and threshold.

        Raises:
            ValueError: If ``adc_data`` has the wrong number of elements.
        """
        frame2d = self._prepare_frame(adc_data)
        total_pressure = float(np.sum(frame2d))
        self._update_dynamic_threshold(total_pressure, timestamp_ms)
        raw_contact = self._is_raw_contact(total_pressure)
        contact_valid = self._debounce_contact(raw_contact, timestamp_ms)
        if not contact_valid:
            self._handle_no_contact()
            return self._make_empty_result(total_pressure)
        if total_pressure > 0.0:
            cop_x, cop_y = self._compute_cop(frame2d, total_pressure)
            cop_x, cop_y = self._filter_cop(cop_x, cop_y)
            self.last_cop_x, self.last_cop_y = cop_x, cop_y
            self._update_state_machine(cop_x, cop_y, timestamp_ms)
        elif self.last_cop_x is not None and self.last_cop_y is not None:
            # An all-zero frame inside the release-debounce window has no CoP.
            # Computing one divides 0/0, and the resulting NaN would poison the
            # low-pass filter and the reference for the rest of the contact.
            # Hold the last CoP instead.
            cop_x, cop_y = self.last_cop_x, self.last_cop_y
        else:
            return self._make_empty_result(total_pressure)
        if self.first_contact_cop_x is None or self.first_contact_cop_y is None:
            dx = 0.0
            dy = 0.0
            base_x = cop_x
            base_y = cop_y
        else:
            base_x = self.first_contact_cop_x
            base_y = self.first_contact_cop_y
            dx = cop_x - base_x
            dy = base_y - cop_y
        magnitude = float(np.hypot(dx, dy))
        angle = self.compute_vector_angle(dx, dy)[0]
        return CoPResult(
            cop_x=float(cop_x),
            cop_y=float(cop_y),
            row_min=0,
            row_max=self.cfg.rows - 1,
            col_min=0,
            col_max=self.cfg.cols - 1,
            dx=float(dx),
            dy=float(dy),
            base_x=float(base_x),
            base_y=float(base_y),
            magnitude=magnitude,
            state=int(self.state),
            total_pressure=total_pressure,
            threshold=self.dynamic_thresh,
            angle=float(angle),
        )

    def _prepare_frame(self, adc_data) -> np.ndarray:
        """Flatten to float32, validate the length, clip negatives, reshape to (rows, cols)."""
        arr = np.asarray(adc_data, dtype=np.float32).flatten()
        expected_len = self.cfg.rows * self.cfg.cols
        if len(arr) != expected_len:
            raise ValueError(f"ADC数据长度必须为{expected_len},当前为{len(arr)}")
        arr = np.clip(arr, 0, None)
        return arr.reshape(self.cfg.rows, self.cfg.cols)

    def _update_dynamic_threshold(self, total_pressure: float, timestamp_ms: float):
        """Collect noise totals until noise_collect_ms has elapsed, then fix the threshold."""
        if self.dynamic_thresh is not None:
            return
        if self.noise_start_ms is None:
            self.noise_start_ms = timestamp_ms
        # NOTE(review): noise_samples grows without bound if the timestamps
        # never advance by noise_collect_ms; confirm the caller's clock.
        self.noise_samples.append(total_pressure)
        if timestamp_ms - self.noise_start_ms >= self.cfg.noise_collect_ms:
            samples = np.asarray(self.noise_samples, dtype=np.float32)
            mean_val = float(np.mean(samples))
            std_val = float(np.std(samples))
            thresh = mean_val + self.cfg.thresh_k * std_val
            self.dynamic_thresh = max(thresh, self.cfg.min_threshold)

    def _is_raw_contact(self, total_pressure: float) -> bool:
        """True when the threshold is calibrated and this frame reaches it."""
        if self.dynamic_thresh is None:
            return False
        return total_pressure >= self.dynamic_thresh

    def _debounce_contact(self, raw_contact: bool, timestamp_ms: float) -> bool:
        """Decide whether contact is valid for this frame.

        A new contact must stay raw for contact_confirm_ms. An active contact
        survives raw drop-outs shorter than release_confirm_ms.
        """
        currently_in_contact = self.state != CoPState.NO_CONTACT
        if raw_contact:
            self.release_candidate_start_ms = None
            if currently_in_contact:
                return True
            if self.contact_candidate_start_ms is None:
                self.contact_candidate_start_ms = timestamp_ms
            return timestamp_ms - self.contact_candidate_start_ms >= self.cfg.contact_confirm_ms
        self.contact_candidate_start_ms = None
        if not currently_in_contact:
            return False
        if self.release_candidate_start_ms is None:
            self.release_candidate_start_ms = timestamp_ms
        if timestamp_ms - self.release_candidate_start_ms >= self.cfg.release_confirm_ms:
            return False
        return True

    def _handle_no_contact(self):
        """Clear per-contact state if a contact was active."""
        if self.state != CoPState.NO_CONTACT:
            self.reset_contact_state()

    def _compute_cop(self, frame2d: np.ndarray, total_pressure: float) -> Tuple[float, float]:
        """Pressure-weighted mean column (x) and row (y) index; total_pressure must be > 0."""
        rows, cols = self.cfg.rows, self.cfg.cols
        x_grid = np.tile(np.arange(cols, dtype=np.float32), (rows, 1))
        y_grid = np.repeat(np.arange(rows, dtype=np.float32), cols).reshape(rows, cols)
        cop_x = float(np.sum(frame2d * x_grid) / total_pressure)
        cop_y = float(np.sum(frame2d * y_grid) / total_pressure)
        return cop_x, cop_y

    def _filter_cop(self, cop_x: float, cop_y: float) -> Tuple[float, float]:
        """Exponential moving average with weight cop_lpf_alpha on the newest sample."""
        alpha = self.cfg.cop_lpf_alpha
        if alpha <= 0.0:
            return cop_x, cop_y
        if self.filtered_cop_x is None or self.filtered_cop_y is None:
            self.filtered_cop_x = cop_x
            self.filtered_cop_y = cop_y
        else:
            self.filtered_cop_x = alpha * cop_x + (1.0 - alpha) * self.filtered_cop_x
            self.filtered_cop_y = alpha * cop_y + (1.0 - alpha) * self.filtered_cop_y
        return self.filtered_cop_x, self.filtered_cop_y

    def _update_state_machine(self, cop_x: float, cop_y: float, timestamp_ms: float):
        """Advance NO_CONTACT -> INIT_COLLECTING -> POST_REFINING -> READY."""
        if self.state == CoPState.NO_CONTACT:
            self.state = CoPState.INIT_COLLECTING
            self.init_start_ms = timestamp_ms
            self.init_x_buf.clear()
            self.init_y_buf.clear()
        if self.state == CoPState.INIT_COLLECTING:
            self.init_x_buf.append(cop_x)
            self.init_y_buf.append(cop_y)
            if self.init_start_ms is None:
                self.init_start_ms = timestamp_ms
            if timestamp_ms - self.init_start_ms >= self.cfg.init_collect_ms:
                # The median of the collected CoPs becomes the reference.
                base_x = float(np.median(self.init_x_buf))
                base_y = float(np.median(self.init_y_buf))
                base_x, base_y = self._apply_center_snap(base_x, base_y)
                self.first_contact_cop_x = base_x
                self.first_contact_cop_y = base_y
                self.post_start_ms = timestamp_ms
                self.post_cand_x = None
                self.post_cand_y = None
                self.post_stable_start_ms = None
                if self.cfg.post_refine_enable:
                    self.state = CoPState.POST_REFINING
                else:
                    self.post_refined = True
                    self.state = CoPState.READY
                return
        if self.state == CoPState.POST_REFINING:
            self._post_refine(cop_x, cop_y, timestamp_ms)

    def _apply_center_snap(self, base_x: float, base_y: float) -> Tuple[float, float]:
        """Return the snap centre if (base_x, base_y) lies within the snap ranges."""
        if not self.cfg.snap_enable:
            return base_x, base_y
        if (
            abs(base_x - self.cfg.snap_center_x) <= self.cfg.snap_range_x
            and abs(base_y - self.cfg.snap_center_y) <= self.cfg.snap_range_y
        ):
            return self.cfg.snap_center_x, self.cfg.snap_center_y
        return base_x, base_y

    def _post_refine(self, cop_x: float, cop_y: float, timestamp_ms: float):
        """Replace the reference with a CoP that stays stable for post_stable_ms.

        Gives up (state READY, reference unchanged) after post_refine_window_ms.
        """
        if self.post_start_ms is None:
            self.post_start_ms = timestamp_ms
        if timestamp_ms - self.post_start_ms >= self.cfg.post_refine_window_ms:
            self.post_refined = True
            self.state = CoPState.READY
            return
        if self.post_cand_x is None or self.post_cand_y is None:
            self.post_cand_x = cop_x
            self.post_cand_y = cop_y
            self.post_stable_start_ms = timestamp_ms
            return
        dist = float(np.hypot(cop_x - self.post_cand_x, cop_y - self.post_cand_y))
        if dist <= self.cfg.post_stable_thresh:
            if self.post_stable_start_ms is None:
                self.post_stable_start_ms = timestamp_ms
            if timestamp_ms - self.post_stable_start_ms >= self.cfg.post_stable_ms:
                refined_x, refined_y = self._apply_center_snap(self.post_cand_x, self.post_cand_y)
                self.first_contact_cop_x = float(refined_x)
                self.first_contact_cop_y = float(refined_y)
                self.post_refined = True
                self.state = CoPState.READY
        else:
            # CoP moved away from the candidate: restart from this sample.
            self.post_cand_x = cop_x
            self.post_cand_y = cop_y
            self.post_stable_start_ms = timestamp_ms

    def _make_empty_result(self, total_pressure: float) -> CoPResult:
        """All-zero NO_CONTACT result carrying only total_pressure and threshold."""
        return CoPResult(
            cop_x=0.0,
            cop_y=0.0,
            row_min=0,
            row_max=self.cfg.rows - 1,
            col_min=0,
            col_max=self.cfg.cols - 1,
            dx=0.0,
            dy=0.0,
            base_x=0.0,
            base_y=0.0,
            magnitude=0.0,
            state=int(CoPState.NO_CONTACT),
            total_pressure=float(total_pressure),
            threshold=self.dynamic_thresh,
            angle=0.0,
        )

    def compute_vector_angle(self, x: float, y: float) -> Tuple[float, float]:
        """Return ``(angle_deg, magnitude)``; the angle is shifted to be non-negative."""
        mag = float(np.hypot(x, y))
        angle = float(np.degrees(np.arctan2(y, x + self.cfg.epsilon)))
        if angle < 0:
            angle += 360.0
        return angle, mag
# Process-wide estimator used by the module-level helper functions below.
_estimator: PressureDirectionEstimator = PressureDirectionEstimator()
def reset_cop_state():
    """Clear the shared estimator's per-contact state; the noise calibration is kept."""
    estimator = _estimator
    estimator.reset_contact_state()
def reset_all_state():
    """Clear the shared estimator completely, including the noise calibration."""
    estimator = _estimator
    estimator.reset_all()
def compute_pressure_direction(adc_data, timestamp_ms: float):
    """Feed one frame to the shared estimator and return a legacy-style 14-tuple.

    The tuple holds the CoPResult fields in declaration order, without
    ``angle``: cop_x, cop_y, row_min, row_max, col_min, col_max, dx, dy,
    base_x, base_y, magnitude, state, total_pressure, threshold.
    """
    outcome = _estimator.update(adc_data, timestamp_ms)
    field_order = (
        "cop_x", "cop_y",
        "row_min", "row_max", "col_min", "col_max",
        "dx", "dy",
        "base_x", "base_y",
        "magnitude", "state",
        "total_pressure", "threshold",
    )
    return tuple(getattr(outcome, name) for name in field_order)
def compute_vector_angle(x: float, y: float) -> Tuple[float, float]:
    """Return ``(angle_deg, magnitude)`` of (x, y) using the shared estimator's epsilon."""
    angle_deg, magnitude = _estimator.compute_vector_angle(x, y)
    return angle_deg, magnitude
def compute_PZT_angle(Px: float, Py: float) -> Tuple[float, float]:
    """Alias of ``compute_vector_angle`` for a (Px, Py) shift vector."""
    angle_deg, magnitude = compute_vector_angle(Px, Py)
    return angle_deg, magnitude
def get_pzt_angle(adc_data, timestamp_ms: float):
    """Feed one ADC frame to the shared estimator and return the servicer's 9-tuple.

    Returns:
        (angle, magnitude, state, cop_x, cop_y, base_x, base_y,
         total_pressure, threshold). The threshold is None until the noise
        floor has been calibrated.

    Raises:
        ValueError: if ``len(adc_data)`` is not ADC_LEN.
    """
    if len(adc_data) == ADC_LEN:
        outcome = _estimator.update(adc_data, timestamp_ms)
        return (
            outcome.angle,
            outcome.magnitude,
            outcome.state,
            outcome.cop_x,
            outcome.cop_y,
            outcome.base_x,
            outcome.base_y,
            outcome.total_pressure,
            outcome.threshold,
        )
    raise ValueError(f"ADC数据长度必须为{ADC_LEN}")
def reset_baseline():
    """Recalibration hook: fully reset the shared estimator, including noise calibration."""
    return reset_all_state()
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="JE-Skin DevKit gRPC Server")
    cli.add_argument(
        "--port",
        type=int,
        default=50051,
        help="gRPC listen port (default: 50051)",
    )
    serve(cli.parse_args().port)