diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index a4ab21b..6898729 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -14,6 +14,8 @@ dependencies = [ "fern", "humantime", "log", + "ndarray", + "ort", "serde", "serde_json", "tauri", @@ -247,6 +249,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + [[package]] name = "bit-set" version = "0.8.0" @@ -556,7 +564,7 @@ dependencies = [ "bitflags 2.11.0", "core-foundation", "core-graphics-types", - "foreign-types", + "foreign-types 0.5.0", "libc", ] @@ -734,6 +742,16 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "der" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.5.8" @@ -1071,6 +1089,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared 0.1.1", +] + [[package]] name = "foreign-types" version = "0.5.0" @@ -1078,7 +1105,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ "foreign-types-macros", - "foreign-types-shared", + "foreign-types-shared 0.3.1", ] [[package]] @@ -1092,6 +1119,12 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "foreign-types-shared" version = "0.3.1" @@ -1576,6 +1609,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac-sha256" +version = "1.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec9d92d097f4749b64e8cc33d924d9f40a2d4eb91402b458014b781f5733d60f" + [[package]] name = "html5ever" version = "0.29.1" @@ -2106,6 +2145,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lzma-rust2" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69" + [[package]] name = "mac" version = "0.1.1" @@ -2163,6 +2208,16 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.8.0" @@ -2240,6 +2295,38 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk" version = "0.9.0" @@ -2305,12 +2392,30 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -2483,6 +2588,50 @@ dependencies = [ "pathdiff", ] +[[package]] +name = "openssl" +version = "0.10.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "foreign-types 0.3.2", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "openssl-sys" +version = "0.9.112" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -2499,6 +2648,30 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "ort" +version = "2.0.0-rc.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7de3af33d24a745ffb8fab904b13478438d1cd52868e6f17735ef6e1f8bf133" +dependencies = [ + "ndarray", + "ort-sys", + "smallvec", + "tracing", + "ureq", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7b497d21a8b6fbb4b5a544f8fadb77e801a09ae0add9e411d31c6f89e3c1e90" +dependencies = [ + "hmac-sha256", + "lzma-rust2", + "ureq", +] + [[package]] name = "pango" version = "0.18.3" @@ -2559,6 +2732,15 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" +[[package]] +name = "pem-rfc7468" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6305423e0e7738146434843d1694d621cce767262b2a86910beab705e4493d9" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -2821,6 +3003,21 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -3046,6 +3243,12 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "redox_syscall" version = "0.5.18" @@ -3177,6 +3380,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -3198,6 +3410,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "schemars" version = "0.8.22" @@ -3255,6 +3476,29 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags 2.11.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "selectors" version = "0.24.0" @@ -3554,6 +3798,17 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "softbuffer" version = "0.4.8" @@ -4507,6 +4762,36 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "ureq" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" +dependencies = [ + "base64 0.22.1", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf8-zero", + "webpki-root-certs", +] + +[[package]] +name = "ureq-proto" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.8" @@ -4538,6 +4823,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf8-zero" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -4556,6 +4847,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version-compare" version = "0.2.1" @@ -4809,6 +5106,15 @@ dependencies = [ "system-deps", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webview2-com" version = "0.38.2" @@ -5600,6 +5906,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zerotrie" version = "0.2.3" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 6458809..7e0c5a1 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -33,3 +33,5 @@ humantime = "2.3.0" csv = "1.4.0" chrono = "0.4.44" crc = "3.4.0" +ort = "2.0.0-rc.12" +ndarray = "0.17.2" \ No newline at end of file diff --git a/src-tauri/resource/round_finish_once.mp3 b/src-tauri/assets/audio/round_finish_once.mp3 similarity index 100% rename from src-tauri/resource/round_finish_once.mp3 rename to src-tauri/assets/audio/round_finish_once.mp3 diff --git a/src-tauri/assets/models/classifier.onnx b/src-tauri/assets/models/classifier.onnx new file mode 100644 index 0000000..ab66033 Binary files /dev/null and b/src-tauri/assets/models/classifier.onnx differ diff --git a/src-tauri/assets/models/model_metadata.json b/src-tauri/assets/models/model_metadata.json new file mode 100644 index 0000000..1c9a71c --- /dev/null +++ b/src-tauri/assets/models/model_metadata.json @@ -0,0 +1,103 @@ +{ + "model_scope": "1d time-series tactile force compensation", + "channel_count": 84, + "feature_names": [ + "sensor_sum", + "sensor_mean", + "sensor_max", + "sensor_min", + "sensor_range", + "sensor_std", + "top5_sum", + "baseline_sum", + "delta_from_baseline", + "naive_force_n", + "slope_1", + "slope_5", + "slope_20", + "slope_100", + "mean_5", + "mean_20", + "mean_100", + "mean_300", + "short_minus_mid", + "mid_minus_long", + "age_above_touch" + ], + "feature_engineering": { + "current_row_stats": [ + "sensor_sum", + "sensor_mean", + "sensor_max", + "sensor_min", + "sensor_range", + "sensor_std", + "top5_sum" + ], + "rolling_mean_windows": [ + 5, + 20, + 100, + 300 + ], + "slope_windows": [ + 1, + 5, + 20, + 100 + ], + "topk_channels": 5, + "baseline_feature": "baseline_sum", + "age_above_touch_feature": "age_above_touch", + "naive_force_feature": "naive_force_n" + }, + "inference_pipeline": [ + "build 21 numeric features in the order listed by feature_names", + "run classifier.onnx or classifier.cbm", + "if predicted_state == 0 then output 0.0 N", + "otherwise run regressor.onnx or regressor.cbm and clamp to >= 0.0 N" + ], + "calibration_points": [ + { + "sensor_sum": 74602.73399170733, + "force_n": 0.9800000000000001 + }, + { + "sensor_sum": 105503.9038227644, + "force_n": 1.9600000000000002 + }, + { + "sensor_sum": 131459.57643184246, + "force_n": 2.94 + }, + { + "sensor_sum": 153512.34776297462, + "force_n": 3.9200000000000004 + }, + { + "sensor_sum": 172041.11212077862, + "force_n": 4.9 + }, + { + "sensor_sum": 193794.83789260528, + "force_n": 5.88 + }, + { + "sensor_sum": 218947.72467683573, + "force_n": 7.840000000000001 + }, + { + "sensor_sum": 240580.4449421614, + "force_n": 9.8 + } + ], + "training_args": { + "positive_step": 80, + "zero_step": 40, + "classifier_iterations": 500, + "regressor_iterations": 700, + "depth": 6, + "learning_rate": 0.05, + "random_seed": 42 + } +} \ No newline at end of file diff --git a/src-tauri/assets/models/regressor.onnx b/src-tauri/assets/models/regressor.onnx new file mode 100644 index 0000000..4903947 Binary files /dev/null and b/src-tauri/assets/models/regressor.onnx differ diff --git a/src-tauri/src/estimator.rs b/src-tauri/src/estimator.rs new file mode 100644 index 0000000..8b10ded --- /dev/null +++ b/src-tauri/src/estimator.rs @@ -0,0 +1,260 @@ +use anyhow::{Context, Result}; +use ndarray::Array2; +use ort::session::Session; +use ort::value::TensorRef; +use serde::Deserialize; +use std::collections::VecDeque; +use std::fs; +use std::path::{Path, PathBuf}; + +const CHANNEL_COUNT: usize = 84; +const FEATURE_COUNT: usize = 21; +const HISTORY_WINDOW: usize = 300; + +#[derive(Debug, Deserialize)] +struct CalibrationPoint { + sensor_sum: f32, + force_n: f32, +} + +#[derive(Debug, Deserialize)] +struct ModelMetadata { + feature_names: Vec, + calibration_points: Vec, +} + +pub struct ForceEstimator { + classifier: Session, + regressor: Session, + metadata: ModelMetadata, + sum_history: VecDeque, + baseline_buffer: Vec, + baseline_sum: f32, + baseline_ready: bool, + touched: bool, + age_above_touch: u32, + touch_delta_threshold: f32, +} + +impl ForceEstimator { + pub fn new(bundle_dir: impl AsRef) -> Result { + let bundle_dir = bundle_dir.as_ref(); + let metadata_path = bundle_dir.join("model_metadata.json"); + let classifier_path = bundle_dir.join("classifier.onnx"); + let regressor_path = bundle_dir.join("regressor.onnx"); + + let metadata: ModelMetadata = serde_json::from_slice( + &fs::read(&metadata_path) + .with_context(|| format!("failed to read {}", metadata_path.display()))?, + ) + .with_context(|| format!("failed to parse {}", metadata_path.display()))?; + + if metadata.feature_names.len() != FEATURE_COUNT { + anyhow::bail!( + "expected {} features, got {}", + FEATURE_COUNT, + metadata.feature_names.len() + ); + } + + let classifier = Session::builder()? + .commit_from_file(&classifier_path) + .with_context(|| format!("failed to load {}", classifier_path.display()))?; + let regressor = Session::builder()? + .commit_from_file(®ressor_path) + .with_context(|| format!("failed to load {}", regressor_path.display()))?; + + Ok(Self { + classifier, + regressor, + metadata, + sum_history: VecDeque::with_capacity(HISTORY_WINDOW), + baseline_buffer: Vec::with_capacity(100), + baseline_sum: 0.0, + baseline_ready: false, + touched: false, + age_above_touch: 0, + touch_delta_threshold: 1_000.0, + }) + } + + pub fn process_frame(&mut self, channels: [f32; CHANNEL_COUNT]) -> Result { + let sensor_sum: f32 = channels.iter().sum(); + + // Simple online baseline rule: + // collect the first 100 low-activity frames, then freeze the baseline. + if !self.baseline_ready && !self.touched { + self.baseline_buffer.push(sensor_sum); + if self.baseline_buffer.len() >= 100 { + let total: f32 = self.baseline_buffer.iter().sum(); + self.baseline_sum = total / self.baseline_buffer.len() as f32; + self.baseline_ready = true; + } + } + + if self.sum_history.len() == HISTORY_WINDOW { + self.sum_history.pop_front(); + } + self.sum_history.push_back(sensor_sum); + + let touch_threshold = self.baseline_sum + self.touch_delta_threshold; + if self.baseline_ready && sensor_sum >= touch_threshold { + self.touched = true; + self.age_above_touch += 1; + } else { + self.age_above_touch = 0; + } + + let features = self.build_features(&channels)?; + let predicted_state = self.predict_state(&features)?; + if predicted_state == 0 { + return Ok(0.0); + } + + let predicted_force = self.predict_force_value(&features)?; + Ok(predicted_force.max(0.0)) + } + + fn build_features(&self, channels: &[f32; CHANNEL_COUNT]) -> Result<[f32; FEATURE_COUNT]> { + let sensor_sum = *self + .sum_history + .back() + .context("cannot build features without at least one frame")?; + let sensor_mean = sensor_sum / CHANNEL_COUNT as f32; + + let mut sorted = channels.to_vec(); + sorted.sort_by(|a, b| b.partial_cmp(a).unwrap()); + let top5_sum: f32 = sorted.iter().take(5).sum(); + + let sensor_max = *sorted.first().unwrap_or(&0.0); + let sensor_min = channels + .iter() + .fold(f32::INFINITY, |acc, &value| acc.min(value)); + let sensor_range = sensor_max - sensor_min; + let sensor_std = population_std(channels, sensor_mean); + let naive_force_n = interpolate_force(sensor_sum, &self.metadata.calibration_points); + + let mean_5 = rolling_mean(&self.sum_history, 5); + let mean_20 = rolling_mean(&self.sum_history, 20); + let mean_100 = rolling_mean(&self.sum_history, 100); + let mean_300 = rolling_mean(&self.sum_history, 300); + + Ok([ + sensor_sum, + sensor_mean, + sensor_max, + sensor_min, + sensor_range, + sensor_std, + top5_sum, + self.baseline_sum, + sensor_sum - self.baseline_sum, + naive_force_n, + slope_from_history(&self.sum_history, 1), + slope_from_history(&self.sum_history, 5), + slope_from_history(&self.sum_history, 20), + slope_from_history(&self.sum_history, 100), + mean_5, + mean_20, + mean_100, + mean_300, + mean_5 - mean_20, + mean_20 - mean_100, + self.age_above_touch as f32, + ]) + } + + fn predict_state(&mut self, features: &[f32; FEATURE_COUNT]) -> Result { + let input = Array2::from_shape_vec((1, FEATURE_COUNT), features.to_vec())?; + let outputs = self + .classifier + .run(ort::inputs![TensorRef::from_array_view(&input)?]?)?; + let first = outputs + .iter() + .next() + .context("classifier returned no outputs")?; + let tensor = first.try_extract_tensor::()?; + let value = tensor + .view() + .iter() + .next() + .copied() + .context("classifier output tensor was empty")?; + Ok(value) + } + + fn predict_force_value(&mut self, features: &[f32; FEATURE_COUNT]) -> Result { + let input = Array2::from_shape_vec((1, FEATURE_COUNT), features.to_vec())?; + let outputs = self + .regressor + .run(ort::inputs![TensorRef::from_array_view(&input)?]?)?; + let first = outputs + .iter() + .next() + .context("regressor returned no outputs")?; + let tensor = first.try_extract_tensor::()?; + let value = tensor + .view() + .iter() + .next() + .copied() + .context("regressor output tensor was empty")?; + Ok(value) + } +} + +fn population_std(values: &[f32], mean: f32) -> f32 { + if values.len() <= 1 { + return 0.0; + } + let variance: f32 = values + .iter() + .map(|value| { + let diff = *value - mean; + diff * diff + }) + .sum::() + / values.len() as f32; + variance.sqrt() +} + +fn rolling_mean(history: &VecDeque, window: usize) -> f32 { + if history.is_empty() { + return 0.0; + } + let count = window.min(history.len()); + let sum: f32 = history.iter().rev().take(count).copied().sum(); + sum / count as f32 +} + +fn slope_from_history(history: &VecDeque, lookback: usize) -> f32 { + if history.is_empty() { + return 0.0; + } + let current = *history.back().unwrap(); + let previous_index = history.len().saturating_sub(lookback + 1); + let previous = history.get(previous_index).copied().unwrap_or(current); + let steps = (history.len() - 1).saturating_sub(previous_index).max(1); + (current - previous) / steps as f32 +} + +fn interpolate_force(sensor_sum: f32, points: &[CalibrationPoint]) -> f32 { + if points.is_empty() { + return 0.0; + } + if sensor_sum <= points[0].sensor_sum { + return points[0].force_n; + } + if sensor_sum >= points[points.len() - 1].sensor_sum { + return points[points.len() - 1].force_n; + } + for pair in points.windows(2) { + let left = &pair[0]; + let right = &pair[1]; + if sensor_sum >= left.sensor_sum && sensor_sum <= right.sensor_sum { + let ratio = (sensor_sum - left.sensor_sum) / (right.sensor_sum - left.sensor_sum); + return left.force_n + ratio * (right.force_n - left.force_n); + } + } + points[points.len() - 1].force_n +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 761fd73..2ad0477 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -1,4 +1,5 @@ mod commands; +pub mod estimator; pub mod log; pub mod serial_core; use commands::serial::SerialConnectionState;