1331 lines
46 KiB
Python
1331 lines
46 KiB
Python
"""
|
||
JE-Skin DevKit — Python gRPC Sensor Server
|
||
|
||
提供两个服务:
|
||
1. SensorPush (streaming) — 接收实时传感器帧
|
||
2. ExportProcessor (unary) — 处理导出的 CSV 文件:梯度过滤、xlsx 转换
|
||
|
||
安装依赖:
|
||
pip install grpcio grpcio-tools openpyxl
|
||
|
||
生成 gRPC 代码:
|
||
python -m grpc_tools.protoc -I../src-tauri/proto --python_out=. --grpc_python_out=. ../src-tauri/proto/sensor_stream.proto
|
||
|
||
启动:
|
||
python sensor_server.py [--port 50051]
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import csv
|
||
import os
|
||
import signal
|
||
import statistics
|
||
import sys
|
||
import time
|
||
from concurrent import futures
|
||
from pathlib import Path
|
||
|
||
import grpc
|
||
import sensor_stream_pb2
|
||
import sensor_stream_pb2_grpc
|
||
|
||
# ── 梯度过滤逻辑(来自用户的 main.py) ─────────────────────────
|
||
|
||
|
||
def load_rows(path: Path) -> list[list[str]]:
    """Read *path* as a UTF-8 CSV (a leading BOM is tolerated).

    Completely empty rows (blank lines) are dropped; every other row is
    returned as the list of its cell strings.
    """
    with path.open("r", encoding="utf-8-sig", newline="") as handle:
        # filter(None, ...) keeps only truthy rows, i.e. non-empty lists.
        return list(filter(None, csv.reader(handle)))
|
||
|
||
|
||
def row_sum(row: list[str]) -> float:
    """Sum the numeric cells of *row*, ignoring the first column and blank cells.

    Raises ValueError if a non-blank cell is not a valid float.
    """
    non_blank = filter(str.strip, row[1:])
    return sum(map(float, non_blank))
|
||
|
||
|
||
def find_threshold(sum_values: list[float]) -> float:
    """Return the midpoint of the largest gap between consecutive sorted values.

    When several gaps tie for largest, the lowest one wins.

    Raises:
        ValueError: fewer than two values were given.
    """
    if len(sum_values) < 2:
        raise ValueError("At least two rows are required.")

    ordered = sorted(sum_values)
    gaps = [upper - lower for lower, upper in zip(ordered, ordered[1:])]
    split_at = gaps.index(max(gaps))
    lower, upper = ordered[split_at], ordered[split_at + 1]
    return (lower + upper) / 2.0
|
||
|
||
|
||
def extract_press_groups(
    rows: list[list[str]], sum_values: list[float], threshold: float
) -> tuple[list[list[str]], list[float]]:
    """Keep rows whose sum reaches *threshold* and average each consecutive run.

    Args:
        rows: CSV rows, parallel to *sum_values*.
        sum_values: Per-row sums.
        threshold: Rows with ``sum >= threshold`` are kept.

    Returns:
        ``(kept_rows, run_means)``. ``kept_rows`` holds all kept rows in order.
        ``run_means`` holds one mean per maximal run of consecutive kept rows.
    """
    kept_rows: list[list[str]] = []
    run_means: list[float] = []
    run: list[float] = []

    def close_run() -> None:
        # Flush the current run of above-threshold sums into run_means.
        if run:
            run_means.append(statistics.fmean(run))
            run.clear()

    for row, total in zip(rows, sum_values):
        if total >= threshold:
            kept_rows.append(row)
            run.append(total)
        else:
            close_run()

    close_run()
    return kept_rows, run_means
|
||
|
||
|
||
def write_csv(path: Path, rows: list[list[str]]) -> Path:
    """Write *rows* beside *path* as ``<stem>_filtered.csv`` (UTF-8 with BOM).

    Returns the path of the file written.
    """
    target = path.with_name(path.stem + "_filtered.csv")
    with target.open("w", encoding="utf-8-sig", newline="") as sink:
        writer = csv.writer(sink)
        for row in rows:
            writer.writerow(row)
    return target
|
||
|
||
|
||
def _coerce_cell(value: str):
    """Return *value* as a float when it is a finite number, else unchanged.

    Anything ``float()`` rejects is kept as text and never raises: "1-2",
    "1.2.3", non-ASCII digits. The same applies to values ``float()`` accepts
    but that should stay text: "nan", "inf" and underscore groupings.
    """
    import math

    text = value.strip()
    if not text or "_" in text:
        return value
    try:
        number = float(text)
    except ValueError:
        return value
    return number if math.isfinite(number) else value


def write_xlsx(path: Path, rows: list[list[str]], stats: dict) -> Path:
    """Write filtered rows and processing statistics to ``<stem>_filtered.xlsx``.

    Sheet "Filtered Data" holds *rows*, with numeric cells stored as numbers.
    Sheet "Statistics" holds the values from *stats* (keys as produced by
    ``process_csv``).

    Returns:
        Path of the workbook written next to *path*.

    Raises:
        RuntimeError: openpyxl is not installed.
    """
    try:
        import openpyxl
    except ImportError as exc:
        raise RuntimeError("openpyxl is required for xlsx output. Install it with: pip install openpyxl") from exc

    from openpyxl.styles import Font, PatternFill

    wb = openpyxl.Workbook()

    # Sheet 1: filtered data
    ws_data = wb.active
    ws_data.title = "Filtered Data"
    for row in rows:
        ws_data.append([_coerce_cell(c) for c in row])

    # Sheet 2: statistics
    ws_stats = wb.create_sheet("Statistics")
    header_font = Font(bold=True, size=11)
    header_fill = PatternFill(start_color="E0E0E0", end_color="E0E0E0", fill_type="solid")

    ws_stats.append(["Parameter", "Value"])
    for ref in ("A1", "B1"):
        ws_stats[ref].font = header_font
        ws_stats[ref].fill = header_fill

    stats_rows = [
        ("Source File", stats.get("source_file", "")),
        ("Total Rows", stats.get("rows_total", 0)),
        ("Filtered Rows", stats.get("rows_kept", 0)),
        ("Groups Used", stats.get("groups_used", 0)),
        ("Mean Value", f"{stats.get('mean_value', 0):.3f}"),
        ("Threshold", f"{stats.get('threshold', 0):.3f}"),
        ("Process Time", stats.get("process_time", "")),
    ]
    for label, value in stats_rows:
        ws_stats.append([label, value])

    ws_stats.column_dimensions["A"].width = 18
    ws_stats.column_dimensions["B"].width = 30

    out = path.with_name(f"{path.stem}_filtered.xlsx")
    wb.save(str(out))
    return out
|
||
|
||
|
||
def process_csv(csv_path: str, save_as_xlsx: bool) -> dict:
    """Run gradient (largest-gap) filtering on an exported CSV and return statistics.

    Each row is summed over every column except the first. Rows whose sum is
    at or above the largest-gap threshold are kept. The result is written
    either as ``<stem>_filtered.xlsx`` (the source CSV is then deleted) or
    back over the source CSV. A summary row is appended to
    ``devkit_analysis_results.xlsx`` in the same directory on a best-effort
    basis.

    Args:
        csv_path: Path of the exported CSV file.
        save_as_xlsx: Write an xlsx workbook instead of rewriting the CSV.

    Returns:
        Dict with ``ok``, ``output_path``, ``groups_used``, ``mean_value``,
        ``threshold``, ``rows_total``, ``rows_kept`` and ``message``.

    Raises:
        FileNotFoundError: ``csv_path`` is not a file.
        ValueError: Raised in any of these cases:
            - the CSV is empty or has fewer than two rows;
            - a data cell is not numeric;
            - no row reaches the threshold.
        RuntimeError: xlsx output was requested but openpyxl is unavailable.
    """
    path = Path(csv_path)
    if not path.is_file():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    rows = load_rows(path)
    if not rows:
        raise ValueError("CSV file is empty.")

    sum_values = [row_sum(r) for r in rows]
    threshold = find_threshold(sum_values)
    filtered_rows, group_means = extract_press_groups(rows, sum_values, threshold)

    if not filtered_rows:
        raise ValueError("No large press-down data was found.")

    overall_mean = statistics.fmean(group_means)

    stats = {
        "source_file": path.name,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "process_time": time.strftime("%Y-%m-%d %H:%M:%S"),
    }

    if save_as_xlsx:
        output_path = write_xlsx(path, filtered_rows, stats)
        # The workbook now holds the data; removing the source CSV is best-effort.
        try:
            path.unlink()
        except OSError as exc:
            print(f"[ExportProcessor] Warning: could not delete source CSV {path}: {exc}")
    else:
        filtered_csv = write_csv(path, filtered_rows)
        # Swap the filtered file in over the source in a single step
        # (os.replace semantics, also overwrites on Windows). On failure the
        # original CSV is left intact. Deleting it first and then renaming
        # would lose it if the rename failed.
        try:
            filtered_csv.replace(path)
            output_path = path
        except OSError as exc:
            output_path = filtered_csv
            print(f"[ExportProcessor] Warning: could not replace {path} with filtered data: {exc}")

    # Let the summary log record the file that was actually produced.
    stats["output_file"] = output_path.name

    # The summary workbook is bookkeeping only. A failure here (e.g. the
    # workbook is locked because it is open in Excel) must not turn an
    # already-completed export into an error.
    try:
        _append_analysis_log(csv_path, stats)
    except Exception as exc:
        print(f"[ExportProcessor] Warning: could not update analysis log: {exc}")

    return {
        "ok": True,
        "output_path": str(output_path),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "message": "OK",
    }
|
||
|
||
|
||
def _append_analysis_log(source_csv: str, stats: dict):
    """Append one summary row for a processed export to devkit_analysis_results.xlsx.

    The workbook lives in the directory of *source_csv*. It is created, with
    a header row, on first use. Does nothing when openpyxl is not installed.

    Args:
        source_csv: Path of the processed CSV; only its parent directory is used.
        stats: Statistics dict as built by ``process_csv``. If it has an
            ``output_file`` entry, that name is logged. Otherwise the
            ``<stem>_filtered`` label derived from ``source_file`` is written.
    """
    try:
        import openpyxl
    except ImportError:
        return  # openpyxl unavailable: skip logging

    log_path = Path(source_csv).parent / "devkit_analysis_results.xlsx"
    sheet_title = "Analysis Log"

    if log_path.exists():
        wb = openpyxl.load_workbook(str(log_path))
        # Prefer the log sheet by name: if the workbook was edited by hand,
        # the active sheet may no longer be the log.
        ws = wb[sheet_title] if sheet_title in wb.sheetnames else wb.active
    else:
        wb = openpyxl.Workbook()
        ws = wb.active
        ws.title = sheet_title
        ws.append(["Time", "Source File", "Total Rows", "Kept Rows",
                   "Groups", "Mean Value", "Threshold", "Output File"])

    output_file = stats.get("output_file") or f"{Path(stats.get('source_file', '')).stem}_filtered"

    ws.append([
        stats.get("process_time", ""),
        stats.get("source_file", ""),
        stats.get("rows_total", 0),
        stats.get("rows_kept", 0),
        stats.get("groups_used", 0),
        round(stats.get("mean_value", 0), 3),
        round(stats.get("threshold", 0), 3),
        output_file,
    ])

    wb.save(str(log_path))
|
||
|
||
|
||
# ── gRPC 服务实现 ────────────────────────────────────────────────
|
||
|
||
|
||
class SensorPushServicer(sensor_stream_pb2_grpc.SensorPushServicer):
    """Receive real-time sensor frames over the streaming ``Upload`` RPC.

    Each frame with a full ``SENSOR_ROWS * SENSOR_COLS`` matrix is passed to
    ``get_pzt_angle``. Every frame is answered with one ``PztAngleResponse``.
    The counters and ``last_*`` fields describe the current stream and are
    used only for console logging.
    """

    # Frames between the periodic "~fps" progress lines.
    REPORT_INTERVAL_FRAMES = 100
    # Detailed per-frame logging: the first N frames of a stream ...
    VERBOSE_FIRST_FRAMES = 10
    # ... and then every Nth frame.
    VERBOSE_EVERY_FRAMES = 30

    def __init__(self):
        self.frame_count = 0
        self.last_report_time = time.time()
        self.last_angle = None
        self.last_state = 0
        self.last_magnitude = 0.0

    def Upload(self, request_iterator, context):
        """Yield one PztAngleResponse per incoming frame of this stream."""
        print("[SensorPush] Client connected, waiting for frames...")
        reset_baseline()

        # Per-stream bookkeeping. Resetting it here makes the verbose
        # first-frames logging and the fps estimate refer to this connection
        # instead of accumulating across reconnects.
        self.frame_count = 0
        self.last_report_time = time.time()
        self.last_angle = None
        self.last_state = 0
        self.last_magnitude = 0.0

        for frame in request_iterator:
            self.frame_count += 1
            yield self._process_frame(frame)

            if self.frame_count % self.REPORT_INTERVAL_FRAMES == 0:
                self._report_progress(frame)

        print(f"[SensorPush] Stream ended. Total: {self.frame_count}")

    def _process_frame(self, frame):
        """Run the PZT angle computation for one frame and build its response.

        Errors are reported in-band (``ok=False`` plus ``message``) rather
        than aborting the stream.
        """
        angle = magnitude = 0.0
        cop_x = cop_y = base_x = base_y = 0.0
        total_press = threshold = 0.0
        state = 0
        ok = True
        message = "OK"

        if len(frame.matrix) == SENSOR_ROWS * SENSOR_COLS:
            try:
                result = get_pzt_angle(frame.matrix, float(frame.dts_ms))
                angle, magnitude, state, cop_x, cop_y, base_x, base_y, total_press, threshold = result
                # The threshold may be None (e.g. before one has been
                # established); the protobuf field needs a number.
                threshold = threshold or 0.0
                self.last_angle = angle
                self.last_state = state
                self.last_magnitude = magnitude
                if (self.frame_count <= self.VERBOSE_FIRST_FRAMES
                        or self.frame_count % self.VERBOSE_EVERY_FRAMES == 0):
                    print(
                        f"[SensorPush] PZT angle frame #{frame.seq} "
                        f"dts={frame.dts_ms} angle={angle:.2f} "
                        f"mag={magnitude:.2f} state={state} "
                        f"cop=({cop_x:.2f},{cop_y:.2f}) "
                        f"base=({base_x:.2f},{base_y:.2f}) "
                        f"total={total_press:.2f} threshold={threshold:.2f}"
                    )
            except Exception as e:
                ok = False
                message = str(e)
                print(f"[SensorPush] PZT compute error on frame #{frame.seq}: {e}")
        else:
            ok = False
            message = f"Invalid matrix length: {len(frame.matrix)}"

        return sensor_stream_pb2.PztAngleResponse(
            seq=frame.seq,
            timestamp_ms=frame.timestamp_ms,
            angle=angle,
            dts_ms=frame.dts_ms,
            ok=ok,
            message=message,
            magnitude=magnitude,
            state=state,
            cop_x=cop_x,
            cop_y=cop_y,
            base_x=base_x,
            base_y=base_y,
            total_press=total_press,
            threshold=threshold,
        )

    def _report_progress(self, frame):
        """Print a one-line summary including the frame rate over the last interval."""
        now = time.time()
        elapsed = now - self.last_report_time
        fps = self.REPORT_INTERVAL_FRAMES / elapsed if elapsed > 0 else 0
        self.last_report_time = now

        angle_text = (
            f"{self.last_angle:.2f}"
            if self.last_angle is not None
            else "n/a"
        )
        print(
            f"[SensorPush] Frame #{frame.seq} | "
            f"{frame.rows}x{frame.cols} | "
            f"angle={angle_text} | "
            f"mag={self.last_magnitude:.2f} | "
            f"state={self.last_state} | "
            f"force={frame.resultant_force:.1f} | "
            f"total={self.frame_count} | ~{fps:.1f} fps"
        )
|
||
|
||
|
||
class ExportProcessorServicer(sensor_stream_pb2_grpc.ExportProcessorServicer):
    """Unary service that post-processes an exported CSV file via ``process_csv``.

    Any exception raised while processing, or while building the reply, is
    returned to the client as ``ok=False`` with the exception text in
    ``message``, instead of failing the RPC.
    """

    # Keys copied verbatim from the process_csv() result into the response.
    _RESPONSE_FIELDS = (
        "ok",
        "output_path",
        "groups_used",
        "mean_value",
        "threshold",
        "rows_total",
        "rows_kept",
        "message",
    )

    def ProcessFile(self, request, context):
        csv_path, save_as_xlsx = request.csv_path, request.save_as_xlsx
        print(f"[ExportProcessor] Processing: {csv_path} (xlsx={save_as_xlsx})")

        try:
            outcome = process_csv(csv_path, save_as_xlsx)
            payload = {name: outcome[name] for name in self._RESPONSE_FIELDS}
            return sensor_stream_pb2.ProcessResponse(**payload)
        except Exception as e:
            print(f"[ExportProcessor] Error: {e}")
            return sensor_stream_pb2.ProcessResponse(ok=False, output_path="", message=str(e))
|
||
|
||
|
||
# ── 启动 ────────────────────────────────────────────────────────
|
||
|
||
|
||
def serve(port: int, host: str = "0.0.0.0", max_workers: int = 4):
    """Run the DevKit gRPC server until SIGINT/SIGTERM, then exit the process.

    Registers ``SensorPushServicer`` and ``ExportProcessorServicer`` on an
    insecure (plaintext) port.

    Args:
        port: TCP port to bind. 0 lets the OS choose; the bound port is printed.
        host: Interface to bind. The default listens on all IPv4 interfaces.
        max_workers: Size of the handler thread pool. With gRPC's synchronous
            server, each open ``Upload`` stream holds a worker thread while it
            runs.

    Raises:
        RuntimeError: The address could not be bound.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    sensor_stream_pb2_grpc.add_SensorPushServicer_to_server(SensorPushServicer(), server)
    sensor_stream_pb2_grpc.add_ExportProcessorServicer_to_server(ExportProcessorServicer(), server)

    listen_addr = f"{host}:{port}"
    bound_port = server.add_insecure_port(listen_addr)
    if bound_port == 0:
        # Some grpcio releases report a failed bind by returning 0 rather than raising.
        raise RuntimeError(f"Failed to bind gRPC server to {listen_addr}")
    server.start()

    print(f"[DevKit Server] gRPC listening on {host}:{bound_port}")
    print("[DevKit Server] Services: SensorPush (streaming), ExportProcessor (unary)")

    def shutdown(signum, frame):
        print("\n[DevKit Server] Shutting down...")
        # stop() only schedules shutdown and returns a threading.Event. Wait
        # for it so in-flight RPCs actually get the grace period before the
        # process exits.
        server.stop(grace=5).wait()
        sys.exit(0)

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)

    server.wait_for_termination()
|
||
|
||
|
||
import numpy as np
|
||
import threading
|
||
from collections import deque
|
||
|
||
# ===================== Tangential-force algorithm parameters ======================
# Module-level tuning for the legacy, frame-count based functions below
# (subtract_baseline / _legacy_*). SENSOR_ROWS and SENSOR_COLS are also read
# by the gRPC servicer, CoPConfig and ADC_LEN.

# Number of CoP samples whose median becomes the first-contact reference.
COP_INIT_MEDIAN_FRAMES = 1
# Number of per-frame total-pressure samples collected before the legacy
# dynamic contact threshold is fixed.
NOISE_COLLECT_FRAMES = 20
# Legacy threshold = THRESH_K * mean of the collected noise sums.
THRESH_K = 5
# Sensor grid size; frames are reshaped to (SENSOR_ROWS, SENSOR_COLS).
SENSOR_ROWS = 12
SENSOR_COLS = 7

# If the first-contact CoP lies within +/-SNAP_RANGE_* of this point, it is
# replaced by the point itself. With both ranges at 0.0 the snap only fires
# on an exact match, so it is effectively disabled.
SNAP_CENTER_X, SNAP_CENTER_Y = 3.0, 5.5
SNAP_RANGE_X = 0.0
SNAP_RANGE_Y = 0.0

# Post-initialisation refinement. For up to POST_INIT_WINDOW_CNT frames after
# the reference is set, a candidate CoP may replace the reference. It must
# stay within POST_INIT_STABLE_THRESH (grid-cell units) of itself for
# POST_INIT_STABLE_CNT consecutive frames.
POST_INIT_WINDOW_CNT = 600000
POST_INIT_STABLE_CNT = 500
POST_INIT_STABLE_THRESH = 0.1
|
||
|
||
# ===================== Global state (legacy algorithm) ======================
# NOTE(review): first_frame_lock is taken by subtract_baseline() and
# _legacy_reset_baseline(). _legacy_compute_pressure_direction() and
# _legacy_reset_cop_state() update the globals below without any lock, so
# concurrent streams would race on them.

# Baseline frame captured by subtract_baseline() on its first call after a reset.
first_frame = None
first_frame_lock = threading.Lock()

# First-contact reference CoP and whether it has been established.
first_contact_CoP_x = None
first_contact_CoP_y = None
contact_initialized = False

# Recent CoP samples whose median becomes the first-contact reference.
cop_init_x_buf = deque(maxlen=COP_INIT_MEDIAN_FRAMES)
cop_init_y_buf = deque(maxlen=COP_INIT_MEDIAN_FRAMES)

# Total-pressure samples used to derive the legacy dynamic threshold
# (None until NOISE_COLLECT_FRAMES samples have been seen).
noise_sum_buf = deque(maxlen=NOISE_COLLECT_FRAMES)
dynamic_thresh = None

# Post-initialisation refinement bookkeeping.
post_init_frame_cnt = 0
post_stable_cnt = 0
post_refined_flag = False
post_cand_x = None
post_cand_y = None
|
||
|
||
# ===================== Baseline subtraction =====================
def subtract_baseline(current_frame):
    """Return *current_frame* minus the stored baseline frame, clipped at zero.

    The first frame seen after a reset becomes the baseline. It is stored in
    the module-level ``first_frame`` under ``first_frame_lock``. The input is
    flattened to a 1-D float32 array before subtraction.
    """
    global first_frame

    sample = np.array(current_frame, dtype=np.float32).flatten()

    with first_frame_lock:
        if first_frame is None:
            first_frame = sample.copy()
        return np.clip(sample - first_frame, 0, None)
|
||
|
||
# ===================== Reset CoP state =====================
def _legacy_reset_cop_state():
    """Forget the legacy first-contact reference and all refinement progress."""
    global first_contact_CoP_x, first_contact_CoP_y, contact_initialized
    global post_init_frame_cnt, post_stable_cnt, post_refined_flag
    global post_cand_x, post_cand_y

    for buffer in (cop_init_x_buf, cop_init_y_buf):
        buffer.clear()

    first_contact_CoP_x = first_contact_CoP_y = None
    post_cand_x = post_cand_y = None
    contact_initialized = post_refined_flag = False
    post_init_frame_cnt = post_stable_cnt = 0
|
||
|
||
|
||
# ===================== CoP (centre of pressure) computation — legacy frame-count version =====================
def _legacy_compute_pressure_direction(baseline_subtracted_frame):
    """Compute the CoP of one frame and its shift from the first-contact reference.

    All progress is kept in module-level globals, so calls are stateful and
    order-dependent.

    Processing steps:
      * Threshold: until NOISE_COLLECT_FRAMES frames have been seen, each
        total is buffered. After that the threshold is
        ``THRESH_K * mean(totals)``. Before the threshold exists, any non-zero
        frame counts as contact.
      * Reference: the median of the first COP_INIT_MEDIAN_FRAMES CoP samples
        becomes the reference. It may be snapped to
        (SNAP_CENTER_X, SNAP_CENTER_Y).
      * Refinement: for up to POST_INIT_WINDOW_CNT frames, a CoP stable within
        POST_INIT_STABLE_THRESH for POST_INIT_STABLE_CNT frames replaces the
        reference.

    Args:
        baseline_subtracted_frame: Values that flatten to
            SENSOR_ROWS * SENSOR_COLS elements.

    Returns:
        A 14-tuple with these fields:
        ``(cop_x, cop_y, row_min=0, row_max=rows-1, col_min=0, col_max=cols-1,
        delta_x, delta_y, base_x, base_y, magnitude, state, total_pressure,
        dynamic_thresh)``.
        ``delta_y`` is ``base_y - cop_y``, so positive values point toward
        lower row indices. ``state`` is 0 (reference not set),
        1 (set, not refined) or 2 (refined / window over).
        Without contact, every field except the bounds, ``total_pressure``
        and ``dynamic_thresh`` is zero. ``dynamic_thresh`` may be None.
    """
    global first_contact_CoP_x, first_contact_CoP_y, contact_initialized
    global post_init_frame_cnt, post_stable_cnt, post_refined_flag
    global post_cand_x, post_cand_y
    global noise_sum_buf, dynamic_thresh

    rows, cols = SENSOR_ROWS, SENSOR_COLS
    frame_flat = np.asarray(baseline_subtracted_frame, dtype=np.float32).flatten()
    frame2d = frame_flat.reshape(rows, cols)

    total_pressure = np.sum(frame2d)

    # Noise-floor phase: derive the contact threshold from the first frames.
    if dynamic_thresh is None:
        noise_sum_buf.append(total_pressure)
        if len(noise_sum_buf) >= NOISE_COLLECT_FRAMES:
            sums = np.array(noise_sum_buf)
            dynamic_thresh = THRESH_K * float(np.mean(sums))

    # No contact: drop the reference (once a threshold exists) and return zeros.
    if total_pressure == 0 or (dynamic_thresh is not None and total_pressure < dynamic_thresh):
        if contact_initialized and dynamic_thresh is not None:
            _legacy_reset_cop_state()
        return 0.0, 0.0, 0, rows-1, 0, cols-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0.0, dynamic_thresh

    # Pressure-weighted centroid in grid-cell index units.
    x_grid = np.tile(np.arange(cols), (rows, 1))
    y_grid = np.repeat(np.arange(rows), cols).reshape(rows, cols)
    cop_x = np.sum(frame2d * x_grid) / total_pressure
    cop_y = np.sum(frame2d * y_grid) / total_pressure

    # Defaults used while the reference is being established.
    delta_CoP_x = 0.0
    delta_CoP_y = 0.0
    base_x = cop_x
    base_y = cop_y

    if not contact_initialized:
        cop_init_x_buf.append(cop_x)
        cop_init_y_buf.append(cop_y)

        if len(cop_init_x_buf) >= COP_INIT_MEDIAN_FRAMES:
            first_contact_CoP_x = float(np.median(cop_init_x_buf))
            first_contact_CoP_y = float(np.median(cop_init_y_buf))
            contact_initialized = True
            cop_init_x_buf.clear()
            cop_init_y_buf.clear()
            # Snap a near-centre reference exactly onto the configured centre.
            if (abs(first_contact_CoP_x - SNAP_CENTER_X) <= SNAP_RANGE_X and
                    abs(first_contact_CoP_y - SNAP_CENTER_Y) <= SNAP_RANGE_Y):
                first_contact_CoP_x = SNAP_CENTER_X
                first_contact_CoP_y = SNAP_CENTER_Y

    else:
        post_init_frame_cnt += 1
        if not post_refined_flag and post_init_frame_cnt <= POST_INIT_WINDOW_CNT:
            # Track a candidate CoP. Count frames that stay close to it and
            # restart with a new candidate when the CoP moves away.
            if post_cand_x is not None:
                dist_val = np.hypot(cop_x - post_cand_x, cop_y - post_cand_y)
                if dist_val <= POST_INIT_STABLE_THRESH:
                    post_stable_cnt += 1
                else:
                    post_cand_x = cop_x
                    post_cand_y = cop_y
                    post_stable_cnt = 1
            else:
                post_cand_x = cop_x
                post_cand_y = cop_y
                post_stable_cnt = 1

            # A long-enough stable candidate replaces the reference.
            if post_stable_cnt >= POST_INIT_STABLE_CNT:
                first_contact_CoP_x = post_cand_x
                first_contact_CoP_y = post_cand_y
                post_refined_flag = True
        else:
            # Refinement window over (or already refined): freeze the reference.
            post_refined_flag = True

        # Shift relative to the reference; y is flipped so that motion toward
        # row 0 gives a positive delta.
        delta_CoP_x = cop_x - first_contact_CoP_x
        delta_CoP_y = first_contact_CoP_y - cop_y
        base_x = first_contact_CoP_x
        base_y = first_contact_CoP_y

    magnitude = np.hypot(delta_CoP_x, delta_CoP_y)
    if not contact_initialized:
        state = 0
    elif not post_refined_flag:
        state = 1
    else:
        state = 2

    return (cop_x, cop_y,
            0, rows-1, 0, cols-1,
            delta_CoP_x, delta_CoP_y,
            base_x, base_y,
            magnitude, state,
            total_pressure, dynamic_thresh)
|
||
|
||
# ===================== 角度计算核心 =====================
|
||
def _legacy_compute_vector_angle(x: float, y: float) -> tuple[float, float]:
|
||
epsilon = 1e-8
|
||
mag = np.hypot(x, y)
|
||
angle = np.degrees(np.arctan2(y, x + epsilon))
|
||
if angle < 0:
|
||
angle += 360
|
||
return angle, mag
|
||
|
||
def _legacy_compute_PZT_angle(Px: float, Py: float) -> tuple[float, float]:
|
||
return _legacy_compute_vector_angle(Px, Py)
|
||
|
||
# ===================== Core entry point =====================
def _legacy_get_pzt_angle(adc_data):
    """Legacy entry point: compute the CoP-shift angle for one ADC frame.

    Args:
        adc_data: Sequence of ``SENSOR_ROWS * SENSOR_COLS`` values. ``len()``
            is checked, so a 2-D array fails on its row count.

    Returns:
        Tuple ``(angle, magnitude, state, cop_x, cop_y, base_x, base_y,
        total_press, threshold)``. ``state`` is an int and ``threshold`` may
        be None.

    Raises:
        ValueError: ``adc_data`` has the wrong length.
    """
    # Derive the expected length from the grid constants (the same ones the
    # reshape in _legacy_compute_pressure_direction uses) instead of a
    # duplicated literal 84.
    expected_len = SENSOR_ROWS * SENSOR_COLS
    if len(adc_data) != expected_len:
        raise ValueError(f"ADC数据长度必须为{expected_len}")

    (cop_x, cop_y,
     _row_min, _row_max, _col_min, _col_max,
     dx, dy,
     base_x, base_y,
     magnitude, state,
     total_press, threshold) = _legacy_compute_pressure_direction(adc_data)

    pzt_angle, _ = _legacy_compute_PZT_angle(dx, dy)
    return pzt_angle, magnitude, int(state), cop_x, cop_y, base_x, base_y, total_press, threshold
|
||
|
||
# ===================== Reset baseline (for calibration) =====================
def _legacy_reset_baseline():
    """Drop the baseline frame and noise threshold, then reset the CoP state.

    The next frame passed to subtract_baseline() becomes the new baseline,
    and the noise threshold is re-collected from scratch.
    """
    global first_frame, dynamic_thresh

    with first_frame_lock:
        first_frame = None
        dynamic_thresh = None
        noise_sum_buf.clear()

    _legacy_reset_cop_state()
|
||
|
||
|
||
from dataclasses import dataclass
|
||
from enum import IntEnum
|
||
from typing import Optional, Tuple
|
||
|
||
|
||
# Number of ADC values in one flattened sensor frame.
ADC_LEN = SENSOR_ROWS * SENSOR_COLS
|
||
|
||
|
||
class CoPState(IntEnum):
    """Contact state of PressureDirectionEstimator; ``CoPResult.state`` carries its int value.

    NOTE(review): the legacy _legacy_compute_pressure_direction reports a
    different scale (0 = reference not set, 1 = set but unrefined,
    2 = refined). State 2 therefore means "refined" there but "still
    refining" here. Confirm that consumers of the gRPC ``state`` field know
    which estimator produced it.
    """

    # No debounced contact; results are zeroed.
    NO_CONTACT = 0
    # Contact confirmed; collecting CoP samples for the median reference.
    INIT_COLLECTING = 1
    # Reference set; waiting for a stable CoP to refine it (or for the window to end).
    POST_REFINING = 2
    # Reference finalised.
    READY = 3
|
||
|
||
|
||
@dataclass
class CoPResult:
    """One frame of output from PressureDirectionEstimator.update()."""

    # Low-pass filtered centre of pressure in grid-cell index units (0.0 when no contact).
    cop_x: float
    cop_y: float
    # Region the CoP was computed over; always the full grid in this estimator.
    row_min: int
    row_max: int
    col_min: int
    col_max: int
    # Shift from the reference: dx = cop_x - base_x, dy = base_y - cop_y
    # (positive dy points toward lower row indices).
    dx: float
    dy: float
    # Reference CoP the shift is measured from.
    base_x: float
    base_y: float
    # Euclidean length of (dx, dy).
    magnitude: float
    # CoPState value as a plain int.
    state: int
    # Sum of the (negative-clipped) frame.
    total_pressure: float
    # Dynamic contact threshold, or None while the noise floor is still being measured.
    threshold: Optional[float]
    # Direction of (dx, dy) in degrees; negative arctan2 results are shifted by +360.
    angle: float
|
||
|
||
|
||
@dataclass
class CoPConfig:
    """Tuning parameters for PressureDirectionEstimator. All *_ms values are
    compared against differences of the caller-supplied ``timestamp_ms``."""

    # Grid shape every frame must flatten to.
    rows: int = SENSOR_ROWS
    cols: int = SENSOR_COLS
    # How long total-pressure samples are collected at start-up before the
    # contact threshold is fixed.
    noise_collect_ms: float = 300.0
    # Threshold = mean + thresh_k * std of the noise samples ...
    thresh_k: float = 5.0
    # ... but never lower than this.
    min_threshold: float = 50.0
    # Raw contact must persist this long before contact is accepted.
    contact_confirm_ms: float = 20.0
    # Raw no-contact must persist this long before contact is released.
    release_confirm_ms: float = 50.0
    # CoP samples gathered for this long after contact; their median is the reference.
    init_collect_ms: float = 80.0
    # Snap the reference to (snap_center_x, snap_center_y) when it lies within
    # +/-snap_range_x / +/-snap_range_y of that point.
    snap_enable: bool = True
    snap_center_x: float = 3.0
    snap_center_y: float = 5.5
    snap_range_x: float = 0.25
    snap_range_y: float = 0.25
    # Optional refinement phase after the reference is set.
    post_refine_enable: bool = True
    # Maximum duration of the refinement phase.
    post_refine_window_ms: float = 800.0
    # A candidate CoP must stay within post_stable_thresh (grid cells) for this
    # long to replace the reference.
    post_stable_ms: float = 200.0
    post_stable_thresh: float = 0.1
    # Weight of the newest CoP sample in the exponential low-pass filter; <= 0 disables it.
    cop_lpf_alpha: float = 0.25
    # Added to the x component before arctan2 in compute_vector_angle().
    epsilon: float = 1e-8
|
||
|
||
|
||
class PressureDirectionEstimator:
    """Estimate the in-plane shift of the centre of pressure (CoP) over time.

    Call update() once per frame with ``cfg.rows * cfg.cols`` ADC values and
    a millisecond timestamp. Negative values are clipped to zero. All timing
    uses the caller-supplied timestamp.

    The estimator works in four stages:
      1. It measures a noise floor for ``noise_collect_ms`` and derives a
         contact threshold: mean + ``thresh_k`` * std, at least
         ``min_threshold``.
      2. It debounces contact entry and exit with ``contact_confirm_ms`` and
         ``release_confirm_ms``.
      3. On contact, the median CoP over ``init_collect_ms`` becomes the
         reference. That reference may later be refined once the CoP has
         been stable.
      4. It reports the low-pass filtered CoP's shift from the reference as
         (dx, dy), together with its magnitude and angle.
    """

    def __init__(self, config: Optional[CoPConfig] = None):
        """Create an estimator.

        Args:
            config: Tuning parameters. ``None`` (the default) creates a fresh
                CoPConfig for this instance.
        """
        # Build the default per instance. A ``CoPConfig()`` default argument
        # is evaluated once at definition time, so every estimator created
        # without a config would share (and could mutate) the same object.
        self.cfg = config if config is not None else CoPConfig()
        self.reset_all()

    def reset_all(self):
        """Forget the noise threshold and all contact state (full recalibration)."""
        self.dynamic_thresh: Optional[float] = None
        self.noise_samples = []
        self.noise_start_ms: Optional[float] = None
        self.reset_contact_state()

    def reset_contact_state(self):
        """Return to NO_CONTACT, dropping the reference, debounce timers and filter memory."""
        self.first_contact_cop_x: Optional[float] = None
        self.first_contact_cop_y: Optional[float] = None
        self.state = CoPState.NO_CONTACT
        self.init_x_buf = []
        self.init_y_buf = []
        self.init_start_ms: Optional[float] = None
        self.post_start_ms: Optional[float] = None
        self.post_stable_start_ms: Optional[float] = None
        self.post_cand_x: Optional[float] = None
        self.post_cand_y: Optional[float] = None
        self.post_refined = False
        self.contact_candidate_start_ms: Optional[float] = None
        self.release_candidate_start_ms: Optional[float] = None
        self.filtered_cop_x: Optional[float] = None
        self.filtered_cop_y: Optional[float] = None

    def update(self, adc_data, timestamp_ms: float) -> CoPResult:
        """Process one frame and return the current CoP estimate.

        Args:
            adc_data: Pressure values that flatten to ``cfg.rows * cfg.cols``
                elements.
            timestamp_ms: Frame time in milliseconds; drives every timer.

        Returns:
            A CoPResult. Without a (debounced) contact, all positional fields
            are zero and ``state`` is NO_CONTACT.

        Raises:
            ValueError: ``adc_data`` has the wrong number of elements.
        """
        frame2d = self._prepare_frame(adc_data)
        total_pressure = float(np.sum(frame2d))

        self._update_dynamic_threshold(total_pressure, timestamp_ms)

        raw_contact = self._is_raw_contact(total_pressure)
        contact_valid = self._debounce_contact(raw_contact, timestamp_ms)

        if not contact_valid:
            self._handle_no_contact()
            return self._make_empty_result(total_pressure)

        if total_pressure > self.cfg.epsilon:
            cop_x, cop_y = self._compute_cop(frame2d, total_pressure)
            cop_x, cop_y = self._filter_cop(cop_x, cop_y)
        elif self.filtered_cop_x is not None and self.filtered_cop_y is not None:
            # Release debouncing still holds contact, but the frame is empty.
            # The centroid is undefined (0/0 -> NaN), and a NaN fed into the
            # low-pass filter would stick for the rest of the contact. Hold
            # the last CoP instead.
            cop_x, cop_y = self.filtered_cop_x, self.filtered_cop_y
        else:
            return self._make_empty_result(total_pressure)

        self._update_state_machine(cop_x, cop_y, timestamp_ms)

        if self.first_contact_cop_x is None or self.first_contact_cop_y is None:
            # Reference not established yet: report zero shift.
            dx = 0.0
            dy = 0.0
            base_x = cop_x
            base_y = cop_y
        else:
            base_x = self.first_contact_cop_x
            base_y = self.first_contact_cop_y
            dx = cop_x - base_x
            # y is flipped so that motion toward row 0 is positive.
            dy = base_y - cop_y

        magnitude = float(np.hypot(dx, dy))
        angle = self.compute_vector_angle(dx, dy)[0]

        return CoPResult(
            cop_x=float(cop_x),
            cop_y=float(cop_y),
            row_min=0,
            row_max=self.cfg.rows - 1,
            col_min=0,
            col_max=self.cfg.cols - 1,
            dx=float(dx),
            dy=float(dy),
            base_x=float(base_x),
            base_y=float(base_y),
            magnitude=magnitude,
            state=int(self.state),
            total_pressure=total_pressure,
            threshold=self.dynamic_thresh,
            angle=float(angle),
        )

    def _prepare_frame(self, adc_data) -> np.ndarray:
        """Flatten, length-check, clip negatives to 0 and reshape to (rows, cols)."""
        arr = np.asarray(adc_data, dtype=np.float32).flatten()
        expected_len = self.cfg.rows * self.cfg.cols

        if len(arr) != expected_len:
            raise ValueError(f"ADC数据长度必须为{expected_len},当前为{len(arr)}")

        arr = np.clip(arr, 0, None)
        return arr.reshape(self.cfg.rows, self.cfg.cols)

    def _update_dynamic_threshold(self, total_pressure: float, timestamp_ms: float):
        """Collect noise samples until ``noise_collect_ms`` has elapsed, then fix the threshold once."""
        if self.dynamic_thresh is not None:
            return

        if self.noise_start_ms is None:
            self.noise_start_ms = timestamp_ms

        self.noise_samples.append(total_pressure)

        if timestamp_ms - self.noise_start_ms >= self.cfg.noise_collect_ms:
            samples = np.asarray(self.noise_samples, dtype=np.float32)
            mean_val = float(np.mean(samples))
            std_val = float(np.std(samples))
            thresh = mean_val + self.cfg.thresh_k * std_val
            self.dynamic_thresh = max(thresh, self.cfg.min_threshold)

    def _is_raw_contact(self, total_pressure: float) -> bool:
        """Undebounced contact test; always False until a threshold exists."""
        if self.dynamic_thresh is None:
            return False
        return total_pressure >= self.dynamic_thresh

    def _debounce_contact(self, raw_contact: bool, timestamp_ms: float) -> bool:
        """Return whether contact counts as active for this frame.

        Entering contact requires raw contact for ``contact_confirm_ms``.
        Leaving it requires raw no-contact for ``release_confirm_ms``. Until
        then the current state is kept.
        """
        currently_in_contact = self.state != CoPState.NO_CONTACT

        if raw_contact:
            self.release_candidate_start_ms = None

            if currently_in_contact:
                return True

            if self.contact_candidate_start_ms is None:
                self.contact_candidate_start_ms = timestamp_ms

            return timestamp_ms - self.contact_candidate_start_ms >= self.cfg.contact_confirm_ms

        self.contact_candidate_start_ms = None

        if not currently_in_contact:
            return False

        if self.release_candidate_start_ms is None:
            self.release_candidate_start_ms = timestamp_ms

        if timestamp_ms - self.release_candidate_start_ms >= self.cfg.release_confirm_ms:
            return False

        return True

    def _handle_no_contact(self):
        """Reset contact state if a contact was in progress."""
        if self.state != CoPState.NO_CONTACT:
            self.reset_contact_state()

    def _compute_cop(self, frame2d: np.ndarray, total_pressure: float) -> Tuple[float, float]:
        """Pressure-weighted centroid in grid-cell index units; *total_pressure* must be non-zero."""
        rows, cols = self.cfg.rows, self.cfg.cols
        x_grid = np.tile(np.arange(cols, dtype=np.float32), (rows, 1))
        y_grid = np.repeat(np.arange(rows, dtype=np.float32), cols).reshape(rows, cols)
        cop_x = float(np.sum(frame2d * x_grid) / total_pressure)
        cop_y = float(np.sum(frame2d * y_grid) / total_pressure)
        return cop_x, cop_y

    def _filter_cop(self, cop_x: float, cop_y: float) -> Tuple[float, float]:
        """Exponential low-pass filter; the first sample of a contact initialises it."""
        alpha = self.cfg.cop_lpf_alpha

        if alpha <= 0.0:
            # Filtering disabled: still remember the latest CoP so update()
            # can hold it across empty frames.
            self.filtered_cop_x, self.filtered_cop_y = cop_x, cop_y
            return cop_x, cop_y

        if self.filtered_cop_x is None or self.filtered_cop_y is None:
            self.filtered_cop_x = cop_x
            self.filtered_cop_y = cop_y
        else:
            self.filtered_cop_x = alpha * cop_x + (1.0 - alpha) * self.filtered_cop_x
            self.filtered_cop_y = alpha * cop_y + (1.0 - alpha) * self.filtered_cop_y

        return self.filtered_cop_x, self.filtered_cop_y

    def _update_state_machine(self, cop_x: float, cop_y: float, timestamp_ms: float):
        """Advance NO_CONTACT -> INIT_COLLECTING -> POST_REFINING -> READY."""
        if self.state == CoPState.NO_CONTACT:
            self.state = CoPState.INIT_COLLECTING
            self.init_start_ms = timestamp_ms
            self.init_x_buf.clear()
            self.init_y_buf.clear()

        if self.state == CoPState.INIT_COLLECTING:
            self.init_x_buf.append(cop_x)
            self.init_y_buf.append(cop_y)

            if self.init_start_ms is None:
                self.init_start_ms = timestamp_ms

            if timestamp_ms - self.init_start_ms >= self.cfg.init_collect_ms:
                # The median of the collected samples becomes the reference.
                base_x = float(np.median(self.init_x_buf))
                base_y = float(np.median(self.init_y_buf))
                base_x, base_y = self._apply_center_snap(base_x, base_y)

                self.first_contact_cop_x = base_x
                self.first_contact_cop_y = base_y
                self.post_start_ms = timestamp_ms
                self.post_cand_x = None
                self.post_cand_y = None
                self.post_stable_start_ms = None

                if self.cfg.post_refine_enable:
                    self.state = CoPState.POST_REFINING
                else:
                    self.post_refined = True
                    self.state = CoPState.READY

            return

        if self.state == CoPState.POST_REFINING:
            self._post_refine(cop_x, cop_y, timestamp_ms)

    def _apply_center_snap(self, base_x: float, base_y: float) -> Tuple[float, float]:
        """Replace (base_x, base_y) by the snap centre when it lies within the snap ranges."""
        if not self.cfg.snap_enable:
            return base_x, base_y

        if (
            abs(base_x - self.cfg.snap_center_x) <= self.cfg.snap_range_x
            and abs(base_y - self.cfg.snap_center_y) <= self.cfg.snap_range_y
        ):
            return self.cfg.snap_center_x, self.cfg.snap_center_y

        return base_x, base_y

    def _post_refine(self, cop_x: float, cop_y: float, timestamp_ms: float):
        """Replace the reference with a candidate CoP once it has been stable for ``post_stable_ms``.

        The phase ends (READY) on success or when ``post_refine_window_ms`` expires.
        """
        if self.post_start_ms is None:
            self.post_start_ms = timestamp_ms

        if timestamp_ms - self.post_start_ms >= self.cfg.post_refine_window_ms:
            self.post_refined = True
            self.state = CoPState.READY
            return

        if self.post_cand_x is None or self.post_cand_y is None:
            self.post_cand_x = cop_x
            self.post_cand_y = cop_y
            self.post_stable_start_ms = timestamp_ms
            return

        dist = float(np.hypot(cop_x - self.post_cand_x, cop_y - self.post_cand_y))

        if dist <= self.cfg.post_stable_thresh:
            if self.post_stable_start_ms is None:
                self.post_stable_start_ms = timestamp_ms

            if timestamp_ms - self.post_stable_start_ms >= self.cfg.post_stable_ms:
                refined_x, refined_y = self._apply_center_snap(self.post_cand_x, self.post_cand_y)
                self.first_contact_cop_x = float(refined_x)
                self.first_contact_cop_y = float(refined_y)
                self.post_refined = True
                self.state = CoPState.READY
        else:
            # Moved away: restart stability tracking around the new position.
            self.post_cand_x = cop_x
            self.post_cand_y = cop_y
            self.post_stable_start_ms = timestamp_ms

    def _make_empty_result(self, total_pressure: float) -> CoPResult:
        """Zeroed NO_CONTACT result that still reports the total pressure and threshold."""
        return CoPResult(
            cop_x=0.0,
            cop_y=0.0,
            row_min=0,
            row_max=self.cfg.rows - 1,
            col_min=0,
            col_max=self.cfg.cols - 1,
            dx=0.0,
            dy=0.0,
            base_x=0.0,
            base_y=0.0,
            magnitude=0.0,
            state=int(CoPState.NO_CONTACT),
            total_pressure=float(total_pressure),
            threshold=self.dynamic_thresh,
            angle=0.0,
        )

    def compute_vector_angle(self, x: float, y: float) -> Tuple[float, float]:
        """Return ``(angle_deg, magnitude)`` of (x, y); negative angles are shifted by +360."""
        mag = float(np.hypot(x, y))
        angle = float(np.degrees(np.arctan2(y, x + self.cfg.epsilon)))

        if angle < 0:
            angle += 360.0

        return angle, mag
|
||
|
||
@dataclass
class LocalForceResult:
    """Per-frame output of LocalTangentialForceEstimator.

    NOTE(review): apart from total_pressure/peak, the field meanings below
    are presumed from their names; confirm them against the estimator that
    fills them in.
    """

    # Direction of the estimated tangential force (presumably degrees — confirm).
    angle: float
    # Length of the estimated tangential force vector.
    magnitude: float
    # Components of the estimated in-plane vector.
    planar_x: float
    planar_y: float
    # Confidence score for the estimate (range not established here — confirm).
    confidence: float
    # Whether the estimator currently considers the sensor in contact.
    contact_active: bool
    # Whether this estimate should be reported to consumers.
    reportable: bool
    # Sum of the baseline-subtracted frame.
    total_pressure: float
    # Maximum cell of the baseline-subtracted frame.
    peak: float
    # Centre of pressure of the contact.
    cop_x: float
    cop_y: float
    # Contact threshold in effect for this frame.
    threshold: float
|
||
|
||
|
||
class LocalTangentialForceEstimator:
    """Estimate an in-plane (tangential) force vector from a pressure-array frame.

    Each call to :meth:`update` takes one raw ADC frame and:

    1. subtracts a slowly adapting idle baseline plus a fixed noise floor;
    2. runs an enter/exit hysteresis state machine to decide whether a
       contact is active;
    3. for an active contact, combines the pressure asymmetry inside the
       contact's bounding box with the drift (from the first usable frame)
       and frame-to-frame motion of the centre of pressure (COP), applies
       edge corrections and "rolling friction", then low-pass filters the
       result into a planar vector;
    4. gates the result through a second hysteresis on magnitude and
       confidence (with a short hold) so ``reportable`` does not flicker.

    The estimator is stateful. :meth:`reset_contact_state` forgets the current
    contact; :meth:`reset_all` additionally drops the learned baseline.
    """

    # Contact detection hysteresis, evaluated on the baseline-subtracted frame.
    CONTACT_ENTER_TOTAL_THRESHOLD = 520.0
    CONTACT_ENTER_PEAK_THRESHOLD = 50.0
    CONTACT_EXIT_TOTAL_THRESHOLD = 260.0
    CONTACT_EXIT_PEAK_THRESHOLD = 28.0
    CONTACT_ENTER_FRAMES_REQUIRED = 2
    CONTACT_EXIT_FRAMES_REQUIRED = 8

    # Baseline tracking: EMA weight while idle, weight for the very first frame,
    # and a constant offset subtracted to suppress noise.
    BASELINE_IDLE_ALPHA = 0.035
    BASELINE_BOOTSTRAP_ALPHA = 1.0
    BASELINE_NOISE_FLOOR = 5.0

    # A cell is "active" when it reaches max(peak * ratio, min value).
    ACTIVE_CELL_MIN_VALUE = 18.0
    ACTIVE_CELL_PEAK_RATIO = 0.14
    MIN_ACTIVE_CELLS = 3

    # Output smoothing and report hysteresis.
    VECTOR_SMOOTHING_ALPHA = 0.16
    REPORT_MAGNITUDE_ENTER = 0.12
    REPORT_MAGNITUDE_EXIT = 0.045
    REPORT_CONFIDENCE_ENTER = 0.14
    REPORT_CONFIDENCE_EXIT = 0.06
    REPORT_HOLD_FRAMES = 10

    # Weights of the individual vector terms and edge/friction corrections.
    ASYMMETRY_WEIGHT = 1.1
    DRIFT_WEIGHT = 0.65
    MOTION_WEIGHT = 0.25
    EDGE_ASYMMETRY_DAMPING = 0.35
    EDGE_INWARD_ROLLING_BIAS = 0.55
    EDGE_START_COP_THRESHOLD = 0.45
    EDGE_START_BIAS_WEIGHT = 1.1
    ROLLING_FRICTION_ALPHA = 0.68
    ROLLING_FRICTION_MIN_MAGNITUDE = 0.05

    def __init__(self):
        self.reset_all()

    def reset_all(self):
        """Forget the idle baseline and all contact/report state."""
        self.baseline_frame = None
        self.reset_contact_state()

    def reset_contact_state(self):
        """Forget the current contact (anchor, smoothing, report hold); keep the baseline."""
        self.contact_active = False
        self.contact_enter_counter = 0
        self.contact_exit_counter = 0
        # COP of the first usable frame of the contact, and of the previous frame.
        self.anchor_cop_x = None
        self.anchor_cop_y = None
        self.last_cop_x = None
        self.last_cop_y = None
        # Constant bias fixed at the anchor frame for contacts starting at an edge.
        self.edge_start_bias_x = 0.0
        self.edge_start_bias_y = 0.0
        # Exponentially smoothed planar vector (row axis not yet flipped).
        self.smoothed_x = 0.0
        self.smoothed_y = 0.0
        # Report hysteresis / hold state.
        self.report_active = False
        self.report_hold_counter = 0
        self.held_report = None

    def update(self, adc_data, timestamp_ms: float) -> "LocalForceResult":
        """Process one frame and return the estimated tangential force.

        Args:
            adc_data: Raw sensor values; flattened, it must hold ``ADC_LEN`` values.
            timestamp_ms: Frame timestamp. Currently not used by the estimator.

        Returns:
            A ``LocalForceResult``. ``reportable`` is True only while the
            magnitude/confidence hysteresis (or its hold period) is active.

        Raises:
            ValueError: if the flattened frame does not contain ``ADC_LEN`` values.
        """
        raw = np.asarray(adc_data, dtype=np.float32).flatten()
        if len(raw) != ADC_LEN:
            raise ValueError(f"ADC data length must be {ADC_LEN}")

        frame = self._subtract_baseline(raw)
        if not self._update_contact_state(raw, frame):
            return self._inactive_result(*self._pressure_metrics(frame))

        stats = self._compute_contact_stats(frame)
        if stats is None:
            # Contact is active but the pressure distribution is unusable
            # (no pressure, or too few active cells).
            return self._stabilize_report(self._weak_contact_result(*self._pressure_metrics(frame)))

        cop_x = stats["cop_x"]
        cop_y = stats["cop_y"]

        if self.anchor_cop_x is None:
            # First usable frame of this contact: remember where it started and
            # the edge-start bias, but emit no direction yet.
            self.anchor_cop_x = cop_x
            self.anchor_cop_y = cop_y
            self.last_cop_x = cop_x
            self.last_cop_y = cop_y
            self.edge_start_bias_x, self.edge_start_bias_y = self._edge_start_bias(stats)
            return self._stabilize_report(
                self._weak_contact_result(stats["total"], stats["peak"], cop_x, cop_y)
            )

        # The anchor and last COP are always set together, so no fallbacks are needed here.
        drift_x = cop_x - self.anchor_cop_x
        drift_y = cop_y - self.anchor_cop_y
        motion_x = cop_x - self.last_cop_x
        motion_y = cop_y - self.last_cop_y

        kinematic_x = drift_x * self.DRIFT_WEIGHT + motion_x * self.MOTION_WEIGHT
        kinematic_y = drift_y * self.DRIFT_WEIGHT + motion_y * self.MOTION_WEIGHT

        asymmetry_x, asymmetry_y = self._damp_edge_asymmetry(
            stats,
            kinematic_x + self.edge_start_bias_x,
            kinematic_y + self.edge_start_bias_y,
        )

        combined_x = asymmetry_x + kinematic_x + self.edge_start_bias_x
        combined_y = asymmetry_y + kinematic_y + self.edge_start_bias_y
        combined_x, combined_y = self._apply_rolling_friction(
            self.smoothed_x,
            self.smoothed_y,
            combined_x,
            combined_y,
        )

        self.smoothed_x += (combined_x - self.smoothed_x) * self.VECTOR_SMOOTHING_ALPHA
        self.smoothed_y += (combined_y - self.smoothed_y) * self.VECTOR_SMOOTHING_ALPHA
        self.last_cop_x = cop_x
        self.last_cop_y = cop_y

        # cop_y is a row index; the y component is negated for output.
        # NOTE(review): presumably to get an "up is positive" convention -- confirm with the UI.
        planar_x = self.smoothed_x
        planar_y = -self.smoothed_y
        angle, magnitude = self.compute_vector_angle(planar_x, planar_y)

        return self._stabilize_report(LocalForceResult(
            angle=angle,
            magnitude=magnitude,
            planar_x=planar_x,
            planar_y=planar_y,
            confidence=self._contact_confidence(stats),
            contact_active=True,
            reportable=False,
            total_pressure=stats["total"],
            peak=stats["peak"],
            cop_x=cop_x,
            cop_y=cop_y,
            threshold=self.CONTACT_ENTER_TOTAL_THRESHOLD,
        ))

    @staticmethod
    def _clamp01(value: float) -> float:
        """Clamp ``value`` to the closed interval [0, 1]."""
        return min(max(value, 0.0), 1.0)

    def _contact_confidence(self, stats) -> float:
        """Heuristic confidence in [0, 1] from contact size, spread and pressure concentration."""
        span_rows = (stats["max_row"] - stats["min_row"] + 1) / SENSOR_ROWS
        span_cols = (stats["max_col"] - stats["min_col"] + 1) / SENSOR_COLS
        activity = self._clamp01(stats["active_cells"] / ADC_LEN)
        span = self._clamp01((span_rows + span_cols) * 0.5)
        pressure_ratio = self._clamp01(stats["active_total"] / max(stats["total"], 1.0))
        # active_cells >= MIN_ACTIVE_CELLS here, so the division is safe.
        peak_ratio = self._clamp01(stats["peak"] / (stats["total"] / stats["active_cells"] + 1.0))
        return self._clamp01(activity * 0.35 + span * 0.2 + pressure_ratio * 0.3 + peak_ratio * 0.15)

    def _update_idle_baseline(self, raw_frame, alpha: float):
        """Blend ``raw_frame`` into the baseline with EMA weight ``alpha``; seed it if unset."""
        if self.baseline_frame is None:
            # np.array copies by default, so the baseline never aliases the input.
            self.baseline_frame = np.array(raw_frame, dtype=np.float32)
            return
        self.baseline_frame += (raw_frame - self.baseline_frame) * alpha

    def _subtract_baseline(self, raw_frame):
        """Return ``raw - baseline - BASELINE_NOISE_FLOOR`` clipped at zero.

        The very first frame seeds the baseline, so it yields an all-zero result.
        """
        if self.baseline_frame is None:
            self._update_idle_baseline(raw_frame, self.BASELINE_BOOTSTRAP_ALPHA)
        diff = raw_frame - self.baseline_frame - self.BASELINE_NOISE_FLOOR
        return np.clip(diff, 0, None)

    def _pressure_metrics(self, frame):
        """Return ``(total, peak)`` of ``frame`` as Python floats (peak is 0.0 for empty input)."""
        return float(np.sum(frame)), float(np.max(frame, initial=0.0))

    def _update_contact_state(self, raw_frame, frame) -> bool:
        """Advance the contact hysteresis and return whether contact is active for this frame.

        Entering needs both total and peak above the enter thresholds for
        ``CONTACT_ENTER_FRAMES_REQUIRED`` consecutive frames. Leaving needs either
        total or peak at/below the exit thresholds for ``CONTACT_EXIT_FRAMES_REQUIRED``
        consecutive frames. The idle baseline is only adapted while not in contact
        (and once on the frame the contact ends).
        """
        total, peak = self._pressure_metrics(frame)
        enter = total >= self.CONTACT_ENTER_TOTAL_THRESHOLD and peak >= self.CONTACT_ENTER_PEAK_THRESHOLD
        exit_frame = total <= self.CONTACT_EXIT_TOTAL_THRESHOLD or peak <= self.CONTACT_EXIT_PEAK_THRESHOLD

        if self.contact_active:
            if exit_frame:
                self.contact_exit_counter += 1
                if self.contact_exit_counter >= self.CONTACT_EXIT_FRAMES_REQUIRED:
                    self._update_idle_baseline(raw_frame, self.BASELINE_IDLE_ALPHA)
                    self.reset_contact_state()
                    return False
            else:
                self.contact_exit_counter = 0
            return True

        if enter:
            self.contact_enter_counter += 1
            if self.contact_enter_counter >= self.CONTACT_ENTER_FRAMES_REQUIRED:
                self.contact_active = True
                self.contact_enter_counter = 0
                self.contact_exit_counter = 0
                return True
            return False

        self.contact_enter_counter = 0
        self._update_idle_baseline(raw_frame, self.BASELINE_IDLE_ALPHA)
        return False

    def _compute_contact_stats(self, frame):
        """Summarise the active region of ``frame`` (reshaped to SENSOR_ROWS x SENSOR_COLS).

        Returns:
            A dict with ``total``, ``peak``, ``active_total``, ``active_cells``,
            the bounding box (``min_row``, ``max_row``, ``min_col``, ``max_col``),
            the pressure-weighted COP (``cop_x`` = column, ``cop_y`` = row) and the
            normalised asymmetry about the bounding-box centre (``asymmetry_x/y``);
            or None when there is no pressure or fewer than ``MIN_ACTIVE_CELLS``
            cells are active.
        """
        total, peak = self._pressure_metrics(frame)
        if total <= 0.0 or peak <= 0.0:
            return None

        active_threshold = max(peak * self.ACTIVE_CELL_PEAK_RATIO, self.ACTIVE_CELL_MIN_VALUE)
        frame2d = np.asarray(frame, dtype=np.float32).reshape(SENSOR_ROWS, SENSOR_COLS)
        active_mask = frame2d >= active_threshold
        active_cells = int(np.count_nonzero(active_mask))
        if active_cells < self.MIN_ACTIVE_CELLS:
            return None

        active_values = frame2d[active_mask]
        active_total = float(np.sum(active_values))
        if active_total <= 0.0:
            return None

        # Boolean indexing and np.nonzero both enumerate in row-major order,
        # so rows/cols line up element-wise with active_values.
        rows, cols = np.nonzero(active_mask)
        cop_x = float(np.sum(active_values * cols) / active_total)
        cop_y = float(np.sum(active_values * rows) / active_total)
        min_row, max_row = int(np.min(rows)), int(np.max(rows))
        min_col, max_col = int(np.min(cols)), int(np.max(cols))

        bbox_center_x = (min_col + max_col) * 0.5
        bbox_center_y = (min_row + max_row) * 0.5
        # Guard against a zero-width/height box.
        half_width = max(max_col - min_col, 1) * 0.5
        half_height = max(max_row - min_row, 1) * 0.5
        asymmetry_x = float(np.sum(active_values * ((cols - bbox_center_x) / half_width)) / active_total)
        asymmetry_y = float(np.sum(active_values * ((rows - bbox_center_y) / half_height)) / active_total)

        return {
            "total": total,
            "peak": peak,
            "active_total": active_total,
            "active_cells": active_cells,
            "min_row": min_row,
            "max_row": max_row,
            "min_col": min_col,
            "max_col": max_col,
            "cop_x": cop_x,
            "cop_y": cop_y,
            "asymmetry_x": asymmetry_x,
            "asymmetry_y": asymmetry_y,
        }

    def _contact_touches_edge(self, stats) -> bool:
        """Return True if the active bounding box touches any border of the sensor."""
        return (
            stats["min_row"] == 0
            or stats["max_row"] == SENSOR_ROWS - 1
            or stats["min_col"] == 0
            or stats["max_col"] == SENSOR_COLS - 1
        )

    def _damp_edge_asymmetry(self, stats, kinematic_x: float, kinematic_y: float):
        """Return the weighted asymmetry vector with edge corrections applied.

        If the contact touches a border and its asymmetry points toward that
        border, the component is flipped inward and scaled by
        ``EDGE_INWARD_ROLLING_BIAS``. If the contact touches any border and the
        resulting asymmetry opposes a non-trivial kinematic vector, it is further
        scaled by ``EDGE_ASYMMETRY_DAMPING``.
        """
        asymmetry_x = stats["asymmetry_x"] * self.ASYMMETRY_WEIGHT
        asymmetry_y = stats["asymmetry_y"] * self.ASYMMETRY_WEIGHT

        if stats["min_col"] == 0 and asymmetry_x < 0.0:
            asymmetry_x = -asymmetry_x * self.EDGE_INWARD_ROLLING_BIAS
        if stats["max_col"] == SENSOR_COLS - 1 and asymmetry_x > 0.0:
            asymmetry_x = -asymmetry_x * self.EDGE_INWARD_ROLLING_BIAS
        if stats["min_row"] == 0 and asymmetry_y < 0.0:
            asymmetry_y = -asymmetry_y * self.EDGE_INWARD_ROLLING_BIAS
        if stats["max_row"] == SENSOR_ROWS - 1 and asymmetry_y > 0.0:
            asymmetry_y = -asymmetry_y * self.EDGE_INWARD_ROLLING_BIAS

        opposing_dot = asymmetry_x * kinematic_x + asymmetry_y * kinematic_y
        kinematic_mag = float(np.hypot(kinematic_x, kinematic_y))
        if self._contact_touches_edge(stats) and opposing_dot < 0.0 and kinematic_mag >= self.ROLLING_FRICTION_MIN_MAGNITUDE:
            asymmetry_x *= self.EDGE_ASYMMETRY_DAMPING
            asymmetry_y *= self.EDGE_ASYMMETRY_DAMPING

        return asymmetry_x, asymmetry_y

    def _edge_start_bias(self, stats):
        """Return a per-axis bias toward the sensor centre for contacts starting on a border axis."""
        center_x = (SENSOR_COLS - 1) * 0.5
        center_y = (SENSOR_ROWS - 1) * 0.5
        # COP offset from the centre, normalised to [-1, 1] per axis.
        normalized_x = min(max((stats["cop_x"] - center_x) / max(center_x, 1.0), -1.0), 1.0)
        normalized_y = min(max((stats["cop_y"] - center_y) / max(center_y, 1.0), -1.0), 1.0)
        touches_col_edge = stats["min_col"] == 0 or stats["max_col"] == SENSOR_COLS - 1
        touches_row_edge = stats["min_row"] == 0 or stats["max_row"] == SENSOR_ROWS - 1
        bias_x = self._edge_start_axis_bias(normalized_x) if touches_col_edge else 0.0
        bias_y = self._edge_start_axis_bias(normalized_y) if touches_row_edge else 0.0
        return bias_x, bias_y

    def _edge_start_axis_bias(self, normalized_axis: float) -> float:
        """Bias pointing back toward the centre, ramping from 0 at the threshold to full weight at the border."""
        distance = abs(normalized_axis)
        if distance <= self.EDGE_START_COP_THRESHOLD:
            return 0.0
        strength = min(max((distance - self.EDGE_START_COP_THRESHOLD) / (1.0 - self.EDGE_START_COP_THRESHOLD), 0.0), 1.0)
        return -np.sign(normalized_axis) * strength * self.EDGE_START_BIAS_WEIGHT

    def _apply_rolling_friction(self, previous_x: float, previous_y: float, current_x: float, current_y: float):
        """Resist sudden direction reversals of the planar vector.

        If both vectors are non-trivial and the new one points against the
        previous one, blend it toward the previous vector. If even the blend
        still reverses, keep the previous direction at half the smaller magnitude.
        """
        previous_mag = float(np.hypot(previous_x, previous_y))
        current_mag = float(np.hypot(current_x, current_y))
        if previous_mag < self.ROLLING_FRICTION_MIN_MAGNITUDE or current_mag < self.ROLLING_FRICTION_MIN_MAGNITUDE:
            return current_x, current_y

        dot = previous_x * current_x + previous_y * current_y
        if dot >= 0.0:
            return current_x, current_y

        mixed_x = current_x * (1.0 - self.ROLLING_FRICTION_ALPHA) + previous_x * self.ROLLING_FRICTION_ALPHA
        mixed_y = current_y * (1.0 - self.ROLLING_FRICTION_ALPHA) + previous_y * self.ROLLING_FRICTION_ALPHA
        if mixed_x * previous_x + mixed_y * previous_y >= 0.0:
            return mixed_x, mixed_y

        keep_mag = min(previous_mag, current_mag) * 0.5
        return previous_x / previous_mag * keep_mag, previous_y / previous_mag * keep_mag

    def _inactive_result(self, total_pressure=0.0, peak=0.0):
        """Result for a frame with no active contact (no direction, not reportable)."""
        return LocalForceResult(
            angle=0.0, magnitude=0.0, planar_x=0.0, planar_y=0.0, confidence=0.0,
            contact_active=False, reportable=False,
            total_pressure=total_pressure, peak=peak, cop_x=0.0, cop_y=0.0,
            threshold=self.CONTACT_ENTER_TOTAL_THRESHOLD,
        )

    def _weak_contact_result(self, total_pressure=0.0, peak=0.0, cop_x=0.0, cop_y=0.0):
        """Result for an active contact that has no usable direction this frame."""
        return LocalForceResult(
            angle=0.0, magnitude=0.0, planar_x=0.0, planar_y=0.0, confidence=0.0,
            contact_active=True, reportable=False,
            total_pressure=total_pressure, peak=peak, cop_x=cop_x, cop_y=cop_y,
            threshold=self.CONTACT_ENTER_TOTAL_THRESHOLD,
        )

    def _store_report(self, result: "LocalForceResult"):
        """Mark ``result`` reportable, remember it for the hold period and return it."""
        result.reportable = True
        self.report_active = True
        self.report_hold_counter = 0
        self.held_report = result
        return result

    def _hold_or_drop_report(self, result=None):
        """Keep showing the last report for up to ``REPORT_HOLD_FRAMES`` frames, then drop it.

        Args:
            result: The current frame's (non-qualifying) result. When the hold
                expires it is returned with ``reportable=False`` so this frame's
                pressure and COP are not lost. If omitted, an all-zero weak-contact
                result is returned instead.
        """
        if self.report_active and self.report_hold_counter < self.REPORT_HOLD_FRAMES and self.held_report is not None:
            self.report_hold_counter += 1
            held = self.held_report
            held.reportable = True
            return held

        self.report_active = False
        self.report_hold_counter = 0
        self.held_report = None
        if result is None:
            return self._weak_contact_result()
        # Previously an all-zero placeholder was returned here, which reported
        # zero pressure/COP for one frame in the middle of an active contact.
        result.reportable = False
        return result

    def _stabilize_report(self, result: "LocalForceResult"):
        """Apply the magnitude/confidence hysteresis that sets ``result.reportable``."""
        if not result.contact_active:
            self.report_active = False
            self.report_hold_counter = 0
            self.held_report = None
            return result

        can_enter = result.magnitude >= self.REPORT_MAGNITUDE_ENTER and result.confidence >= self.REPORT_CONFIDENCE_ENTER
        can_stay = result.magnitude >= self.REPORT_MAGNITUDE_EXIT and result.confidence >= self.REPORT_CONFIDENCE_EXIT
        if self.report_active:
            if can_stay:
                return self._store_report(result)
            return self._hold_or_drop_report(result)
        if can_enter:
            return self._store_report(result)
        return result

    def compute_vector_angle(self, x: float, y: float) -> Tuple[float, float]:
        """Return ``(angle_degrees, magnitude)`` of ``(x, y)``; angle in [0, 360).

        Vectors shorter than float32 epsilon yield ``(0.0, 0.0)``.
        """
        magnitude = float(np.hypot(x, y))
        if magnitude <= np.finfo(np.float32).eps:
            return 0.0, 0.0
        angle = float(np.degrees(np.arctan2(y, x)))
        if angle < 0.0:
            angle += 360.0
        return angle, magnitude
|
||
|
||
|
||
# Process-wide estimator shared by the module-level helper functions; its
# baseline and contact state persist between calls.
_estimator: LocalTangentialForceEstimator = LocalTangentialForceEstimator()
|
||
|
||
|
||
def reset_cop_state():
    """Forget the current contact on the shared estimator, keeping its learned baseline."""
    estimator = _estimator
    estimator.reset_contact_state()
|
||
|
||
|
||
def reset_all_state():
    """Reset the shared estimator completely, including its baseline."""
    estimator = _estimator
    estimator.reset_all()
|
||
|
||
|
||
def compute_pressure_direction(adc_data, timestamp_ms: float):
    """Feed one frame to the shared estimator and return a flat 14-tuple.

    Tuple layout:
        ``(cop_x, cop_y, 0, SENSOR_ROWS - 1, 0, SENSOR_COLS - 1,
        planar_x, planar_y, 0.0, 0.0, magnitude, reportable_flag,
        total_pressure, threshold)``

    Slots 2-5 are the whole-sensor row/column index bounds, not the contact's
    bounding box. Slots 8-9 are constant ``0.0``. ``reportable_flag`` is 1 or 0.
    NOTE(review): the constant slots look like placeholders kept for the caller's
    tuple layout -- confirm before changing them.
    """
    res = _estimator.update(adc_data, timestamp_ms)
    row_bounds = (0, SENSOR_ROWS - 1)
    col_bounds = (0, SENSOR_COLS - 1)
    flag = int(bool(res.reportable))
    return (
        res.cop_x,
        res.cop_y,
        *row_bounds,
        *col_bounds,
        res.planar_x,
        res.planar_y,
        0.0,
        0.0,
        res.magnitude,
        flag,
        res.total_pressure,
        res.threshold,
    )
|
||
|
||
|
||
def compute_vector_angle(x: float, y: float) -> Tuple[float, float]:
    """Return ``(angle_degrees, magnitude)`` of ``(x, y)`` via the shared estimator."""
    estimator = _estimator
    return estimator.compute_vector_angle(x, y)
|
||
|
||
|
||
def compute_PZT_angle(Px: float, Py: float) -> Tuple[float, float]:
    """Same as ``compute_vector_angle(Px, Py)``: returns ``(angle_degrees, magnitude)``."""
    angle, magnitude = compute_vector_angle(Px, Py)
    return angle, magnitude
|
||
|
||
|
||
def get_pzt_angle(adc_data, timestamp_ms: float):
    """Feed one frame to the shared estimator and return a flat 9-tuple.

    Returns:
        ``(angle, magnitude, reportable_flag, cop_x, cop_y, 0.0, 0.0,
        total_pressure, threshold)``, where ``reportable_flag`` is 1 or 0 and the
        two ``0.0`` slots are constant.

    Raises:
        ValueError: if ``len(adc_data) != ADC_LEN`` (message in Chinese:
            "ADC data length must be <ADC_LEN>").
    """
    if len(adc_data) != ADC_LEN:
        raise ValueError(f"ADC数据长度必须为{ADC_LEN}")

    res = _estimator.update(adc_data, timestamp_ms)
    flag = int(bool(res.reportable))
    return (
        res.angle,
        res.magnitude,
        flag,
        res.cop_x,
        res.cop_y,
        0.0,
        0.0,
        res.total_pressure,
        res.threshold,
    )
|
||
|
||
|
||
def reset_baseline():
    """Drop the learned baseline and all contact state; delegates to ``reset_all_state``."""
    return reset_all_state()
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: parse the listen port and start the server.
    cli = argparse.ArgumentParser(description="JE-Skin DevKit gRPC Server")
    cli.add_argument(
        "--port",
        type=int,
        default=50051,
        help="gRPC listen port (default: 50051)",
    )
    serve(cli.parse_args().port)
|