From 28012e70e0ae630c0b1da5f533a5332594a42e86 Mon Sep 17 00:00:00 2001 From: spencer Date: Thu, 12 Mar 2026 13:35:17 -0400 Subject: [PATCH] feat: add storage and training modules for snore detection - Implemented `storage.py` for managing metadata storage, including sample addition, retrieval, and review state management. - Created `training.py` for training a local model using Random Forest, including functions for training and predicting samples. - Developed a web interface in `app.js` for capturing audio samples, managing labels, and training the model. - Added HTML structure in `index.html` for the SnoreStopper control room with sections for sample capture, overnight gathering, training, and status display. - Styled the application with `styles.css` to enhance user experience and interface aesthetics. --- .gitignore | 36 ++ README.md | 102 +++++ data/meta/.gitkeep | 0 data/models/.gitkeep | 0 data/raw/.gitkeep | 0 data/spectrograms/.gitkeep | 0 project_design.md | 12 + pyproject.toml | 28 ++ requirements.txt | 10 + src/snorestopper/__init__.py | 4 + src/snorestopper/audio.py | 68 ++++ src/snorestopper/config.py | 43 +++ src/snorestopper/features.py | 95 +++++ src/snorestopper/main.py | 555 +++++++++++++++++++++++++++ src/snorestopper/overnight.py | 167 ++++++++ src/snorestopper/schemas.py | 98 +++++ src/snorestopper/storage.py | 119 ++++++ src/snorestopper/training.py | 164 ++++++++ web/app.js | 689 ++++++++++++++++++++++++++++++++++ web/index.html | 107 ++++++ web/styles.css | 383 +++++++++++++++++++ 21 files changed, 2680 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 data/meta/.gitkeep create mode 100644 data/models/.gitkeep create mode 100644 data/raw/.gitkeep create mode 100644 data/spectrograms/.gitkeep create mode 100644 project_design.md create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 src/snorestopper/__init__.py create mode 100644 src/snorestopper/audio.py create mode 
100644 src/snorestopper/config.py create mode 100644 src/snorestopper/features.py create mode 100644 src/snorestopper/main.py create mode 100644 src/snorestopper/overnight.py create mode 100644 src/snorestopper/schemas.py create mode 100644 src/snorestopper/storage.py create mode 100644 src/snorestopper/training.py create mode 100644 web/app.js create mode 100644 web/index.html create mode 100644 web/styles.css diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..58fad56 --- /dev/null +++ b/.gitignore @@ -0,0 +1,36 @@ +# Python caches and build artifacts +__pycache__/ +*.py[cod] +*.pyo +*.pyd +*.so +*.egg-info/ +build/ +dist/ + +# Virtual environments and local settings +.venv/ +venv/ +.env + +# Test and tooling caches +.pytest_cache/ +.mypy_cache/ +.coverage +htmlcov/ + +# VS Code local workspace settings +.vscode/* +!.vscode/extensions.json + +# Generated app data +data/raw/* +data/spectrograms/* +data/models/* +data/meta/*.json + +# Keep generated-data folder structure tracked +!data/raw/.gitkeep +!data/spectrograms/.gitkeep +!data/models/.gitkeep +!data/meta/.gitkeep diff --git a/README.md b/README.md new file mode 100644 index 0000000..2f9b286 --- /dev/null +++ b/README.md @@ -0,0 +1,102 @@ +# SnoreStopper v2 + +SnoreStopper is a self-hosted Python + web UI project for collecting night audio samples, labeling snore events, and training a local classifier. 
+ +This starter build gives you: +- Audio input device discovery from the server machine +- On-demand audio sample capture to compressed FLAC files +- Spectrogram generation for each sample +- Browser-based labeling workflow (`snore`, `not_snore`, `unclear`) +- Local model training with scikit-learn +- Per-sample detection using the trained model +- Overnight sample gathering with configurable clip interval +- Snore watch proposal queue with thumbs up/down review gating before training + +## Project Layout + +```text +SnoreStopper_v2/ + data/ + raw/ # Recorded FLAC clips + spectrograms/ # PNG spectrogram images + models/ # Trained model artifacts + meta/ # Metadata JSON (samples, labels) + src/snorestopper/ + main.py # FastAPI app and endpoints + audio.py # Device listing and recording + features.py # Audio IO + spectrogram + features + training.py # Local model training/prediction + storage.py # Metadata persistence + schemas.py # API request/response models + web/ + index.html + styles.css + app.js + requirements.txt + pyproject.toml +``` + +## Quick Start (Windows PowerShell) + +1. Create and activate a virtual environment. +```powershell +python -m venv .venv +.\.venv\Scripts\Activate.ps1 +``` + +2. Install dependencies. +```powershell +python -m pip install --upgrade pip +pip install -r requirements.txt +``` + +3. Start the local server. +```powershell +uvicorn --app-dir src snorestopper.main:app --reload --host 127.0.0.1 --port 8000 +``` + +4. Open the app. +- Browser: `http://127.0.0.1:8000` +- API docs: `http://127.0.0.1:8000/docs` + +## Typical Workflow + +1. Choose an input device and record short clips over several nights. +2. Review each clip and spectrogram. +3. Label clips in the UI. +4. Train the model locally from your labeled dataset. +5. Start an overnight run with `auto_watch` enabled to capture clips and queue predictions. +6. Approve or invert watch proposals with thumbs up/down in the review queue. +7. Retrain the model from approved labels. 
+ +## Overnight + Snore Watch Flow + +- `Overnight Gatherer` captures clips on a fixed interval for N hours. +- If `auto_watch` is enabled, each clip is scored by the local model. +- Predictions are stored as pending proposals and are **not** used for training yet. +- You decide per clip: + - `Thumbs Up` -> proposal becomes approved training label + - `Thumbs Down` -> proposal is inverted (`snore <-> not_snore`) and approved for training +- Manual labels still work and can override watch proposals. + +## Notes + +- This is intentionally self-hosted and local-first: all recorded data, labels, and model artifacts stay on your machine. +- The current model is a baseline (RandomForest + handcrafted spectral features) so you can get to a working loop quickly. +- Recording quality and label quality are the main drivers of model performance. + +## Environment Variables (Optional) + +- `SNORESTOPPER_ROOT`: Override project root directory +- `SNORESTOPPER_SAMPLE_RATE`: Default `16000` +- `SNORESTOPPER_CHANNELS`: Default `1` +- `SNORESTOPPER_MIN_DURATION`: Default `2` +- `SNORESTOPPER_MAX_DURATION`: Default `90` +- `SNORESTOPPER_MODEL_FILE`: Default `snore_classifier.joblib` + +## Next Build Targets + +- Scheduled overnight capture jobs +- Better event segmentation and confidence thresholds +- Hardware trigger module for anti-snoring actions +- User profiles and per-user local models diff --git a/data/meta/.gitkeep b/data/meta/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/models/.gitkeep b/data/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/raw/.gitkeep b/data/raw/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/spectrograms/.gitkeep b/data/spectrograms/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/project_design.md b/project_design.md new file mode 100644 index 0000000..0062d58 --- /dev/null +++ b/project_design.md @@ -0,0 +1,12 @@ +I want to build an app based on Python and a 
web UI + +The project goal is to build an all-in-one solution for tracking snoring and tracking sleep. + +to do this we need to allow the end user to +- select an audio input device from the ones connected to the 'server' +- to record ambient room noise and gather samples of snoring + - to do this we'll need to define good sample sizes and periods of recording + - we'll also need to save the files in an audio format that doesn't pin the CPU/GPU but also doesn't waste filesize pointlessly. + - We will likely also want to store the produced data from the sample that will be used by the network for training (I am thinking a basic spectrogram stored as a png) + - User will record samples over a couple nights, find examples of snoring, train the AI, let it try flagging the snoring events itself and once it's done trigger the anti-snoring events +- must be self-hosted and self-trained, designed to be safe and appless so people are comfortable with passive recording diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c29e4e3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "snorestopper" +version = "0.1.0" +description = "Self-hosted snore tracking and snore event detection" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "fastapi>=0.115,<1.0", + "uvicorn[standard]>=0.30,<1.0", + "numpy>=2.0,<3.0", + "sounddevice>=0.5,<1.0", + "soundfile>=0.12,<1.0", + "scipy>=1.13,<2.0", + "matplotlib>=3.9,<4.0", + "scikit-learn>=1.5,<2.0", + "joblib>=1.4,<2.0", + "python-multipart>=0.0.9,<1.0", +] + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f3dd749 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +fastapi>=0.115,<1.0 +uvicorn[standard]>=0.30,<1.0 +numpy>=2.0,<3.0 
+sounddevice>=0.5,<1.0 +soundfile>=0.12,<1.0 +scipy>=1.13,<2.0 +matplotlib>=3.9,<4.0 +scikit-learn>=1.5,<2.0 +joblib>=1.4,<2.0 +python-multipart>=0.0.9,<1.0 diff --git a/src/snorestopper/__init__.py b/src/snorestopper/__init__.py new file mode 100644 index 0000000..ed8db0d --- /dev/null +++ b/src/snorestopper/__init__.py @@ -0,0 +1,4 @@ +"""SnoreStopper package.""" + +__all__ = ["__version__"] +__version__ = "0.1.0" diff --git a/src/snorestopper/audio.py b/src/snorestopper/audio.py new file mode 100644 index 0000000..395a08c --- /dev/null +++ b/src/snorestopper/audio.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +try: + import sounddevice as sd +except Exception: + sd = None + + +def _require_sounddevice() -> None: + if sd is None: + raise RuntimeError( + "sounddevice is unavailable. Install requirements and ensure PortAudio is available." + ) + + +def list_input_devices() -> list[dict[str, Any]]: + _require_sounddevice() + devices = sd.query_devices() + + input_devices: list[dict[str, Any]] = [] + for index, device in enumerate(devices): + max_input_channels = int(device.get("max_input_channels", 0)) + if max_input_channels <= 0: + continue + + input_devices.append( + { + "index": index, + "name": str(device.get("name", f"Input Device {index}")), + "max_input_channels": max_input_channels, + "default_samplerate": float(device.get("default_samplerate", 0.0)), + } + ) + + return input_devices + + +def capture_audio( + duration_seconds: int, + sample_rate: int, + channels: int, + device_index: int | None = None, +) -> np.ndarray: + _require_sounddevice() + + if duration_seconds <= 0: + raise RuntimeError("duration_seconds must be positive") + + frame_count = int(duration_seconds * sample_rate) + recording = sd.rec( + frame_count, + samplerate=sample_rate, + channels=channels, + dtype="float32", + device=device_index, + ) + sd.wait() + + audio = np.asarray(recording, dtype=np.float32) + if audio.ndim == 2: + 
# Collapse to mono for simpler feature extraction and training. + audio = np.mean(audio, axis=1) + + return audio diff --git a/src/snorestopper/config.py b/src/snorestopper/config.py new file mode 100644 index 0000000..d8b319e --- /dev/null +++ b/src/snorestopper/config.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import os + + +@dataclass(frozen=True) +class Settings: + project_root: Path + data_dir: Path + raw_dir: Path + spectrogram_dir: Path + models_dir: Path + metadata_dir: Path + web_dir: Path + sample_rate: int + channels: int + min_duration_seconds: int + max_duration_seconds: int + model_file_name: str + + +def load_settings() -> Settings: + project_root = Path( + os.getenv("SNORESTOPPER_ROOT", Path(__file__).resolve().parents[2]) + ) + data_dir = project_root / "data" + + return Settings( + project_root=project_root, + data_dir=data_dir, + raw_dir=data_dir / "raw", + spectrogram_dir=data_dir / "spectrograms", + models_dir=data_dir / "models", + metadata_dir=data_dir / "meta", + web_dir=project_root / "web", + sample_rate=int(os.getenv("SNORESTOPPER_SAMPLE_RATE", "16000")), + channels=int(os.getenv("SNORESTOPPER_CHANNELS", "1")), + min_duration_seconds=int(os.getenv("SNORESTOPPER_MIN_DURATION", "2")), + max_duration_seconds=int(os.getenv("SNORESTOPPER_MAX_DURATION", "90")), + model_file_name=os.getenv("SNORESTOPPER_MODEL_FILE", "snore_classifier.joblib"), + ) diff --git a/src/snorestopper/features.py b/src/snorestopper/features.py new file mode 100644 index 0000000..a59b801 --- /dev/null +++ b/src/snorestopper/features.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import soundfile as sf +from scipy import signal + + +def save_audio_as_flac(audio: np.ndarray, sample_rate: int, output_path: Path) -> None: + output_path.parent.mkdir(parents=True, 
exist_ok=True) + sf.write(file=str(output_path), data=audio, samplerate=sample_rate, format="FLAC") + + +def read_audio(audio_path: Path) -> tuple[np.ndarray, int]: + audio, sample_rate = sf.read(str(audio_path), always_2d=False) + waveform = np.asarray(audio, dtype=np.float32) + if waveform.ndim == 2: + waveform = np.mean(waveform, axis=1) + return waveform, int(sample_rate) + + +def _compute_spectrogram(audio: np.ndarray, sample_rate: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + n_samples = int(audio.shape[0]) + if n_samples <= 1: + nperseg = 1 + noverlap = 0 + else: + nperseg = min(1024, n_samples) + noverlap = min(int(nperseg * 0.5), nperseg - 1) + + return signal.spectrogram( + audio, + fs=sample_rate, + nperseg=nperseg, + noverlap=noverlap, + scaling="spectrum", + mode="magnitude", + ) + + +def save_spectrogram(audio: np.ndarray, sample_rate: int, output_path: Path) -> None: + frequencies, times, spectrum = _compute_spectrogram(audio, sample_rate) + power_db = 10.0 * np.log10(spectrum + 1e-12) + + output_path.parent.mkdir(parents=True, exist_ok=True) + fig, axis = plt.subplots(figsize=(8, 3)) + mesh = axis.pcolormesh(times, frequencies, power_db, shading="gouraud", cmap="magma") + axis.set_title("SnoreStopper Spectrogram") + axis.set_ylabel("Frequency (Hz)") + axis.set_xlabel("Time (s)") + fig.colorbar(mesh, ax=axis, label="Power (dB)") + fig.tight_layout() + fig.savefig(output_path, dpi=120) + plt.close(fig) + + +def _band_energy(power: np.ndarray, frequencies: np.ndarray, low: float, high: float) -> float: + mask = (frequencies >= low) & (frequencies < high) + if not np.any(mask): + return 0.0 + return float(np.mean(power[mask, :])) + + +def extract_feature_vector(audio: np.ndarray, sample_rate: int) -> np.ndarray: + frequencies, _, spectrum = _compute_spectrogram(audio, sample_rate) + power = np.log1p(spectrum) + + mean_power = float(np.mean(power)) + std_power = float(np.std(power)) + max_power = float(np.max(power)) + + low_band = 
_band_energy(power, frequencies, 20.0, 250.0) + mid_band = _band_energy(power, frequencies, 250.0, 1000.0) + high_band = _band_energy(power, frequencies, 1000.0, 4000.0) + + spectral_weights = np.mean(power, axis=1) + 1e-9 + spectral_centroid = float(np.average(frequencies, weights=spectral_weights)) + + return np.asarray( + [ + mean_power, + std_power, + max_power, + low_band, + mid_band, + high_band, + spectral_centroid, + ], + dtype=np.float32, + ) diff --git a/src/snorestopper/main.py b/src/snorestopper/main.py new file mode 100644 index 0000000..b915053 --- /dev/null +++ b/src/snorestopper/main.py @@ -0,0 +1,555 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import Any +from uuid import uuid4 + +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.staticfiles import StaticFiles + +from .audio import capture_audio, list_input_devices +from .config import Settings, load_settings +from .features import save_audio_as_flac, save_spectrogram +from .overnight import OvernightCaptureController +from .schemas import ( + ApiMessage, + CaptureRequest, + DetectRequest, + DetectResponse, + DeviceInfo, + LabelRequest, + OvernightStartRequest, + OvernightStatusResponse, + SampleRecord, + TrainRequest, + TrainResponse, +) +from .storage import MetadataStore, ensure_directories +from .training import predict_sample, train_local_model + +settings = load_settings() +metadata_file = ensure_directories(settings) +store = MetadataStore(metadata_file) + +BINARY_LABELS = {"snore", "not_snore"} +VALID_SAMPLE_LABELS = {"snore", "not_snore", "unclear"} +REVIEW_STATES = {"none", "pending", "approved", "rejected", "manual"} +LABEL_SOURCES = {"none", "manual", "watch"} + +app = FastAPI( + title="SnoreStopper", + version="0.1.0", + description="Self-hosted snore tracking, labeling, and model training", +) + 
+app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.mount("/web", StaticFiles(directory=settings.web_dir), name="web") +app.mount("/data", StaticFiles(directory=settings.data_dir), name="data") + + +def _to_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _to_optional_float(value: Any) -> float | None: + if value is None: + return None + + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _invert_binary_label(label: str) -> str: + if label == "snore": + return "not_snore" + if label == "not_snore": + return "snore" + raise ValueError(f"Unsupported binary label: {label}") + + +def _to_sample_record(record: dict[str, Any]) -> SampleRecord: + audio_file = str(record.get("audio_file", "")) + spectrogram_file = str(record.get("spectrogram_file", "")) + label = record.get("label") + if label not in VALID_SAMPLE_LABELS: + label = None + + proposed_label = record.get("proposed_label") + if proposed_label not in BINARY_LABELS: + proposed_label = None + + review_state = str(record.get("review_state", "none")) + if review_state not in REVIEW_STATES: + review_state = "none" + + label_source = str(record.get("label_source", "none")) + if label_source not in LABEL_SOURCES: + label_source = "none" + + return SampleRecord( + sample_id=str(record.get("sample_id", "")), + created_at=str(record.get("created_at", "")), + duration_seconds=_to_int(record.get("duration_seconds"), default=0), + sample_rate=_to_int(record.get("sample_rate"), default=settings.sample_rate), + device_index=record.get("device_index"), + tag=record.get("tag"), + label=label, + audio_file=audio_file, + spectrogram_file=spectrogram_file, + audio_url=f"/data/raw/{audio_file}", + spectrogram_url=f"/data/spectrograms/{spectrogram_file}", + proposed_label=proposed_label, + 
proposed_confidence=_to_optional_float(record.get("proposed_confidence")), + review_state=review_state, + training_approved=bool(record.get("training_approved", False)), + label_source=label_source, + ) + + +def _resolve_model_path(model_path_value: str | None) -> Path: + if not model_path_value: + return settings.models_dir / settings.model_file_name + + candidate = Path(model_path_value).expanduser() + if not candidate.is_absolute(): + candidate = settings.project_root / candidate + return candidate + + +def _validate_capture_duration_seconds(duration_seconds: int, app_settings: Settings) -> None: + if duration_seconds < app_settings.min_duration_seconds: + raise HTTPException( + status_code=400, + detail=( + f"duration_seconds must be at least {app_settings.min_duration_seconds} " + "seconds" + ), + ) + + if duration_seconds > app_settings.max_duration_seconds: + raise HTTPException( + status_code=400, + detail=( + f"duration_seconds must be {app_settings.max_duration_seconds} " + "seconds or less" + ), + ) + + +def _capture_and_store_sample( + *, + duration_seconds: int, + device_index: int | None, + tag: str | None, + watch_after_capture: bool, +) -> dict[str, Any]: + _validate_capture_duration_seconds(duration_seconds, settings) + + try: + waveform = capture_audio( + duration_seconds=duration_seconds, + sample_rate=settings.sample_rate, + channels=settings.channels, + device_index=device_index, + ) + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Recording failed: {exc}") from exc + + sample_id = uuid4().hex + audio_file = f"{sample_id}.flac" + spectrogram_file = f"{sample_id}.png" + + audio_path = settings.raw_dir / audio_file + spectrogram_path = settings.spectrogram_dir / spectrogram_file + + try: + save_audio_as_flac(waveform, settings.sample_rate, audio_path) + save_spectrogram(waveform, settings.sample_rate, spectrogram_path) + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Saving sample failed: 
{exc}") from exc + + record: dict[str, Any] = { + "sample_id": sample_id, + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "duration_seconds": duration_seconds, + "sample_rate": settings.sample_rate, + "device_index": device_index, + "tag": tag, + "label": None, + "audio_file": audio_file, + "spectrogram_file": spectrogram_file, + "proposed_label": None, + "proposed_confidence": None, + "review_state": "none", + "training_approved": False, + "label_source": "none", + } + + if watch_after_capture: + model_path = settings.models_dir / settings.model_file_name + if model_path.exists(): + try: + prediction = predict_sample( + audio_file=audio_file, + raw_dir=settings.raw_dir, + model_path=model_path, + ) + if prediction.predicted_label in BINARY_LABELS: + record.update( + { + "proposed_label": prediction.predicted_label, + "proposed_confidence": prediction.confidence, + "review_state": "pending", + "training_approved": False, + "label_source": "watch", + } + ) + except Exception: + # Capture should still be retained even when watch inference fails. 
+ pass + + store.add_sample(record) + return record + + +def _background_capture_job(payload: dict[str, Any]) -> dict[str, Any]: + return _capture_and_store_sample( + duration_seconds=_to_int(payload.get("duration_seconds"), default=10), + device_index=payload.get("device_index"), + tag=str(payload.get("tag")) if payload.get("tag") is not None else None, + watch_after_capture=bool(payload.get("watch_after_capture", False)), + ) + + +overnight_controller = OvernightCaptureController( + run_capture_job=_background_capture_job, + pending_review_count=store.count_pending_review_samples, +) + + +@app.get("/", include_in_schema=False) +def root_redirect() -> RedirectResponse: + return RedirectResponse(url="/web/index.html") + + +@app.get("/api/health") +def health() -> dict[str, Any]: + return { + "status": "ok", + "project_root": str(settings.project_root), + "sample_rate": settings.sample_rate, + "channels": settings.channels, + } + + +@app.get("/api/audio/devices", response_model=list[DeviceInfo]) +def get_audio_devices() -> list[DeviceInfo]: + try: + devices = list_input_devices() + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Unable to query input devices: {exc}") from exc + + return [DeviceInfo(**item) for item in devices] + + +@app.get("/api/samples", response_model=list[SampleRecord]) +def list_samples() -> list[SampleRecord]: + return [_to_sample_record(item) for item in store.list_samples()] + + +@app.post("/api/samples/capture", response_model=SampleRecord) +def capture_sample(payload: CaptureRequest) -> SampleRecord: + record = _capture_and_store_sample( + duration_seconds=payload.duration_seconds, + device_index=payload.device_index, + tag=payload.tag, + watch_after_capture=payload.watch_after_capture, + ) + return _to_sample_record(record) + + +@app.post("/api/samples/{sample_id}/label", response_model=SampleRecord) +def label_sample(sample_id: str, payload: LabelRequest) -> SampleRecord: + training_approved = payload.label in 
BINARY_LABELS + try: + updated = store.update_sample_fields( + sample_id=sample_id, + updates={ + "label": payload.label, + "review_state": "manual", + "training_approved": training_approved, + "label_source": "manual", + "proposed_label": None, + "proposed_confidence": None, + }, + ) + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + + return _to_sample_record(updated) + + +@app.post("/api/watch/{sample_id}/propose", response_model=SampleRecord) +def propose_watch_label(sample_id: str) -> SampleRecord: + sample = store.get_sample(sample_id) + if sample is None: + raise HTTPException(status_code=404, detail=f"Sample '{sample_id}' was not found") + + if bool(sample.get("training_approved", False)): + raise HTTPException( + status_code=409, + detail="Sample is already approved for training. Clear or relabel it before watch proposal.", + ) + + model_path = settings.models_dir / settings.model_file_name + if not model_path.exists(): + raise HTTPException( + status_code=404, + detail=( + "No trained model found. Train a model first, then run snore watch proposals." + ), + ) + + try: + prediction = predict_sample( + audio_file=str(sample.get("audio_file", "")), + raw_dir=settings.raw_dir, + model_path=model_path, + ) + except FileNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Watch proposal failed: {exc}") from exc + + if prediction.predicted_label not in BINARY_LABELS: + raise HTTPException( + status_code=500, + detail=( + "Model predicted an unsupported class. Retrain with labels 'snore' and 'not_snore'." 
+ ), + ) + + try: + updated = store.update_sample_fields( + sample_id=sample_id, + updates={ + "label": None, + "proposed_label": prediction.predicted_label, + "proposed_confidence": prediction.confidence, + "review_state": "pending", + "training_approved": False, + "label_source": "watch", + }, + ) + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + + return _to_sample_record(updated) + + +@app.get("/api/watch/pending", response_model=list[SampleRecord]) +def list_pending_watch_reviews() -> list[SampleRecord]: + pending_samples = store.list_pending_review_samples() + return [_to_sample_record(item) for item in pending_samples] + + +@app.post("/api/watch/{sample_id}/thumbs-up", response_model=SampleRecord) +def approve_watch_prediction(sample_id: str) -> SampleRecord: + sample = store.get_sample(sample_id) + if sample is None: + raise HTTPException(status_code=404, detail=f"Sample '{sample_id}' was not found") + + proposed_label = sample.get("proposed_label") + if proposed_label not in BINARY_LABELS: + raise HTTPException( + status_code=400, + detail="Sample has no pending watch proposal to approve", + ) + + try: + updated = store.update_sample_fields( + sample_id=sample_id, + updates={ + "label": proposed_label, + "review_state": "approved", + "training_approved": True, + "label_source": "watch", + }, + ) + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + + return _to_sample_record(updated) + + +@app.post("/api/watch/{sample_id}/thumbs-down", response_model=SampleRecord) +def reject_watch_prediction(sample_id: str) -> SampleRecord: + sample = store.get_sample(sample_id) + if sample is None: + raise HTTPException(status_code=404, detail=f"Sample '{sample_id}' was not found") + + if sample.get("review_state") != "pending": + raise HTTPException( + status_code=400, + detail="Sample is not currently pending watch review", + ) + + proposed_label = sample.get("proposed_label") + if 
proposed_label not in BINARY_LABELS: + raise HTTPException( + status_code=400, + detail="Sample has no pending watch proposal to invert", + ) + + try: + corrected_label = _invert_binary_label(str(proposed_label)) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + try: + updated = store.update_sample_fields( + sample_id=sample_id, + updates={ + "label": corrected_label, + "review_state": "approved", + "training_approved": True, + "label_source": "watch", + }, + ) + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + + return _to_sample_record(updated) + + +@app.post("/api/train", response_model=TrainResponse) +def train_model(payload: TrainRequest) -> TrainResponse: + model_path = settings.models_dir / settings.model_file_name + samples = store.list_samples() + + try: + result = train_local_model( + samples=samples, + raw_dir=settings.raw_dir, + model_path=model_path, + test_size=payload.test_size, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Training failed: {exc}") from exc + + return TrainResponse( + model_path=result.model_path, + trained_samples=result.trained_samples, + classes=result.classes, + accuracy=result.accuracy, + ) + + +@app.get("/api/overnight/status", response_model=OvernightStatusResponse) +def overnight_status() -> OvernightStatusResponse: + return OvernightStatusResponse(**overnight_controller.get_status()) + + +@app.post("/api/overnight/start", response_model=OvernightStatusResponse) +def start_overnight_capture(payload: OvernightStartRequest) -> OvernightStatusResponse: + _validate_capture_duration_seconds(payload.clip_duration_seconds, settings) + + if payload.interval_seconds < payload.clip_duration_seconds: + raise HTTPException( + status_code=400, + detail="interval_seconds must be greater than or equal to clip_duration_seconds", + 
) + + if payload.auto_watch: + model_path = settings.models_dir / settings.model_file_name + if not model_path.exists(): + raise HTTPException( + status_code=400, + detail=( + "auto_watch is enabled but no trained model was found. Train a model " + "first or disable auto_watch." + ), + ) + + try: + status = overnight_controller.start_session( + duration_hours=payload.duration_hours, + clip_duration_seconds=payload.clip_duration_seconds, + interval_seconds=payload.interval_seconds, + device_index=payload.device_index, + tag_prefix=payload.tag_prefix, + auto_watch=payload.auto_watch, + ) + except RuntimeError as exc: + raise HTTPException(status_code=409, detail=str(exc)) from exc + + return OvernightStatusResponse(**status) + + +@app.post("/api/overnight/stop", response_model=OvernightStatusResponse) +def stop_overnight_capture() -> OvernightStatusResponse: + return OvernightStatusResponse(**overnight_controller.stop_session()) + + +@app.post("/api/detect", response_model=DetectResponse) +def detect_sample(payload: DetectRequest) -> DetectResponse: + sample = store.get_sample(payload.sample_id) + if sample is None: + raise HTTPException(status_code=404, detail=f"Sample '{payload.sample_id}' was not found") + + model_path = _resolve_model_path(payload.model_path) + try: + prediction = predict_sample( + audio_file=str(sample["audio_file"]), + raw_dir=settings.raw_dir, + model_path=model_path, + ) + except FileNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Detection failed: {exc}") from exc + + return DetectResponse( + sample_id=payload.sample_id, + predicted_label=prediction.predicted_label, + confidence=prediction.confidence, + model_path=str(model_path), + ) + + +@app.post("/api/reset", response_model=ApiMessage) +def reset_sample_labels() -> ApiMessage: + # Keep recordings but reset labels to quickly restart training experiments. 
+ samples = store.list_samples() + for sample in samples: + if sample.get("sample_id"): + store.update_sample_fields( + sample_id=str(sample["sample_id"]), + updates={ + "label": None, + "proposed_label": None, + "proposed_confidence": None, + "review_state": "none", + "training_approved": False, + "label_source": "none", + }, + ) + + return ApiMessage(message="All sample labels and watch proposals were reset") diff --git a/src/snorestopper/overnight.py b/src/snorestopper/overnight.py new file mode 100644 index 0000000..3a7c23b --- /dev/null +++ b/src/snorestopper/overnight.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from threading import Event, Lock, Thread +from typing import Any, Callable +from uuid import uuid4 + + +class OvernightCaptureController: + def __init__( + self, + run_capture_job: Callable[[dict[str, Any]], dict[str, Any]], + pending_review_count: Callable[[], int], + ): + self._run_capture_job = run_capture_job + self._pending_review_count = pending_review_count + self._lock = Lock() + self._stop_event = Event() + self._thread: Thread | None = None + self._state: dict[str, Any] = { + "running": False, + "session_id": None, + "started_at": None, + "planned_end_at": None, + "next_capture_at": None, + "duration_hours": None, + "clip_duration_seconds": None, + "interval_seconds": None, + "auto_watch": None, + "captured_samples": 0, + "last_error": None, + } + + def start_session( + self, + *, + duration_hours: int, + clip_duration_seconds: int, + interval_seconds: int, + device_index: int | None, + tag_prefix: str, + auto_watch: bool, + ) -> dict[str, Any]: + with self._lock: + if self._state["running"]: + raise RuntimeError("Overnight capture is already running") + + session_id = uuid4().hex + started_at = datetime.now(tz=timezone.utc) + planned_end_at = started_at + timedelta(hours=duration_hours) + + self._stop_event = Event() + self._state = { + "running": True, + "session_id": 
session_id, + "started_at": started_at.isoformat(), + "planned_end_at": planned_end_at.isoformat(), + "next_capture_at": started_at.isoformat(), + "duration_hours": duration_hours, + "clip_duration_seconds": clip_duration_seconds, + "interval_seconds": interval_seconds, + "auto_watch": auto_watch, + "captured_samples": 0, + "last_error": None, + } + + worker = Thread( + target=self._run_loop, + kwargs={ + "session_id": session_id, + "planned_end_at": planned_end_at, + "clip_duration_seconds": clip_duration_seconds, + "interval_seconds": interval_seconds, + "device_index": device_index, + "tag_prefix": tag_prefix, + "auto_watch": auto_watch, + "stop_event": self._stop_event, + }, + daemon=True, + name="snorestopper-overnight-capture", + ) + self._thread = worker + worker.start() + + return self._status_locked() + + def stop_session(self) -> dict[str, Any]: + thread_to_join: Thread | None = None + with self._lock: + if not self._state["running"]: + return self._status_locked() + + self._stop_event.set() + thread_to_join = self._thread + + if thread_to_join is not None: + thread_to_join.join(timeout=2.0) + + with self._lock: + return self._status_locked() + + def get_status(self) -> dict[str, Any]: + with self._lock: + return self._status_locked() + + def _status_locked(self) -> dict[str, Any]: + status = dict(self._state) + status["pending_reviews"] = int(self._pending_review_count()) + return status + + def _run_loop( + self, + *, + session_id: str, + planned_end_at: datetime, + clip_duration_seconds: int, + interval_seconds: int, + device_index: int | None, + tag_prefix: str, + auto_watch: bool, + stop_event: Event, + ) -> None: + sample_number = 0 + next_capture_at = datetime.now(tz=timezone.utc) + + try: + while not stop_event.is_set(): + now = datetime.now(tz=timezone.utc) + if now >= planned_end_at: + break + + wait_seconds = (next_capture_at - now).total_seconds() + if wait_seconds > 0: + stop_event.wait(wait_seconds) + continue + + sample_number += 1 + 
capture_started_at = datetime.now(tz=timezone.utc) + tag = f"{tag_prefix}-{session_id[:6]}-{sample_number:04d}" + + try: + self._run_capture_job( + { + "duration_seconds": clip_duration_seconds, + "device_index": device_index, + "tag": tag, + "watch_after_capture": auto_watch, + } + ) + with self._lock: + self._state["captured_samples"] = sample_number + self._state["last_error"] = None + except Exception as exc: + with self._lock: + self._state["last_error"] = str(exc) + + next_capture_at = capture_started_at + timedelta(seconds=interval_seconds) + with self._lock: + if next_capture_at < planned_end_at: + self._state["next_capture_at"] = next_capture_at.isoformat() + else: + self._state["next_capture_at"] = None + finally: + with self._lock: + self._state["running"] = False + self._state["next_capture_at"] = None + self._thread = None diff --git a/src/snorestopper/schemas.py b/src/snorestopper/schemas.py new file mode 100644 index 0000000..49e837c --- /dev/null +++ b/src/snorestopper/schemas.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field + +SampleLabel = Literal["snore", "not_snore", "unclear"] +BinaryLabel = Literal["snore", "not_snore"] +ReviewState = Literal["none", "pending", "approved", "rejected", "manual"] +LabelSource = Literal["none", "manual", "watch"] + + +class DeviceInfo(BaseModel): + index: int + name: str + max_input_channels: int + default_samplerate: float + + +class CaptureRequest(BaseModel): + duration_seconds: int = Field(default=10, ge=1, le=600) + device_index: int | None = None + tag: str | None = Field(default=None, max_length=80) + watch_after_capture: bool = False + + +class SampleRecord(BaseModel): + sample_id: str + created_at: str + duration_seconds: int + sample_rate: int + device_index: int | None = None + tag: str | None = None + label: SampleLabel | None = None + audio_file: str + spectrogram_file: str + audio_url: str + spectrogram_url: str + 
proposed_label: BinaryLabel | None = None + proposed_confidence: float | None = None + review_state: ReviewState = "none" + training_approved: bool = False + label_source: LabelSource = "none" + + +class LabelRequest(BaseModel): + label: SampleLabel + + +class TrainRequest(BaseModel): + test_size: float = Field(default=0.25, ge=0.1, le=0.5) + + +class TrainResponse(BaseModel): + model_path: str + trained_samples: int + classes: list[str] + accuracy: float | None = None + + +class DetectRequest(BaseModel): + sample_id: str + model_path: str | None = None + + +class DetectResponse(BaseModel): + sample_id: str + predicted_label: str + confidence: float + model_path: str + + +class OvernightStartRequest(BaseModel): + duration_hours: int = Field(default=8, ge=1, le=12) + clip_duration_seconds: int = Field(default=20, ge=2, le=600) + interval_seconds: int = Field(default=30, ge=2, le=600) + device_index: int | None = None + tag_prefix: str = Field(default="overnight", max_length=80) + auto_watch: bool = True + + +class OvernightStatusResponse(BaseModel): + running: bool + session_id: str | None = None + started_at: str | None = None + planned_end_at: str | None = None + next_capture_at: str | None = None + duration_hours: int | None = None + clip_duration_seconds: int | None = None + interval_seconds: int | None = None + auto_watch: bool | None = None + captured_samples: int = 0 + pending_reviews: int = 0 + last_error: str | None = None + + +class ApiMessage(BaseModel): + message: str diff --git a/src/snorestopper/storage.py b/src/snorestopper/storage.py new file mode 100644 index 0000000..818eb8e --- /dev/null +++ b/src/snorestopper/storage.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import json +from pathlib import Path +from threading import Lock +from typing import Any + +from .config import Settings + + +def ensure_directories(settings: Settings) -> Path: + for path in ( + settings.data_dir, + settings.raw_dir, + settings.spectrogram_dir, + 
settings.models_dir, + settings.metadata_dir, + ): + path.mkdir(parents=True, exist_ok=True) + + metadata_file = settings.metadata_dir / "samples.json" + if not metadata_file.exists(): + metadata_file.write_text("[]", encoding="utf-8") + return metadata_file + + +class MetadataStore: + def __init__(self, metadata_file: Path): + self.metadata_file = metadata_file + self._lock = Lock() + self.metadata_file.parent.mkdir(parents=True, exist_ok=True) + if not self.metadata_file.exists(): + self.metadata_file.write_text("[]", encoding="utf-8") + + def _read(self) -> list[dict[str, Any]]: + if not self.metadata_file.exists(): + return [] + + raw = self.metadata_file.read_text(encoding="utf-8").strip() + if not raw: + return [] + + try: + payload = json.loads(raw) + except json.JSONDecodeError: + return [] + + if not isinstance(payload, list): + return [] + + return [item for item in payload if isinstance(item, dict)] + + def _write(self, records: list[dict[str, Any]]) -> None: + self.metadata_file.write_text( + json.dumps(records, indent=2, ensure_ascii=True), + encoding="utf-8", + ) + + def list_samples(self) -> list[dict[str, Any]]: + with self._lock: + records = self._read() + + return sorted(records, key=lambda item: item.get("created_at", ""), reverse=True) + + def get_sample(self, sample_id: str) -> dict[str, Any] | None: + with self._lock: + records = self._read() + + for record in records: + if record.get("sample_id") == sample_id: + return record + return None + + def add_sample(self, record: dict[str, Any]) -> dict[str, Any]: + with self._lock: + records = self._read() + records.append(record) + self._write(records) + + return record + + def update_label(self, sample_id: str, label: str) -> dict[str, Any]: + with self._lock: + records = self._read() + for record in records: + if record.get("sample_id") == sample_id: + record["label"] = label + self._write(records) + return record + + raise KeyError(f"Sample '{sample_id}' was not found") + + def 
update_sample_fields( + self, + sample_id: str, + updates: dict[str, Any], + ) -> dict[str, Any]: + with self._lock: + records = self._read() + for record in records: + if record.get("sample_id") == sample_id: + record.update(updates) + self._write(records) + return record + + raise KeyError(f"Sample '{sample_id}' was not found") + + def list_pending_review_samples(self) -> list[dict[str, Any]]: + with self._lock: + records = self._read() + + pending = [item for item in records if item.get("review_state") == "pending"] + return sorted(pending, key=lambda item: item.get("created_at", ""), reverse=True) + + def count_pending_review_samples(self) -> int: + with self._lock: + records = self._read() + + return sum(1 for item in records if item.get("review_state") == "pending") diff --git a/src/snorestopper/training.py b/src/snorestopper/training.py new file mode 100644 index 0000000..03ba586 --- /dev/null +++ b/src/snorestopper/training.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import joblib +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import accuracy_score +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder + +from .features import extract_feature_vector, read_audio + + +def _is_training_eligible(sample: dict[str, Any]) -> bool: + label = sample.get("label") + if label not in {"snore", "not_snore"}: + return False + + approved_flag = sample.get("training_approved") + if approved_flag is None: + # Backward compatibility for samples created before review gating existed. 
+ return True + + return bool(approved_flag) + + +@dataclass(frozen=True) +class TrainingResult: + model_path: str + trained_samples: int + classes: list[str] + accuracy: float | None + + +@dataclass(frozen=True) +class PredictionResult: + predicted_label: str + confidence: float + + +def _build_training_matrix( + samples: list[dict[str, Any]], + raw_dir: Path, +) -> tuple[np.ndarray, list[str]]: + features: list[np.ndarray] = [] + labels: list[str] = [] + + for sample in samples: + if not _is_training_eligible(sample): + continue + + label = sample.get("label") + audio_file = sample.get("audio_file") + if not label or not audio_file: + continue + + audio_path = raw_dir / str(audio_file) + if not audio_path.exists(): + continue + + audio, sample_rate = read_audio(audio_path) + if audio.size == 0: + continue + + features.append(extract_feature_vector(audio, sample_rate)) + labels.append(str(label)) + + if len(features) < 2: + raise ValueError("At least 2 labeled samples are required for training") + + if len(set(labels)) < 2: + raise ValueError("Training needs at least 2 distinct labels") + + return np.vstack(features), labels + + +def train_local_model( + samples: list[dict[str, Any]], + raw_dir: Path, + model_path: Path, + test_size: float = 0.25, +) -> TrainingResult: + features, labels = _build_training_matrix(samples, raw_dir) + + encoder = LabelEncoder() + encoded_labels = encoder.fit_transform(labels) + + class_counts = np.bincount(encoded_labels) + can_split = ( + len(encoded_labels) >= 8 + and len(class_counts) >= 2 + and bool(np.all(class_counts >= 2)) + ) + + if can_split: + x_train, x_test, y_train, y_test = train_test_split( + features, + encoded_labels, + test_size=test_size, + random_state=42, + stratify=encoded_labels, + ) + else: + x_train = features + y_train = encoded_labels + x_test = None + y_test = None + + model = RandomForestClassifier( + n_estimators=200, + random_state=42, + class_weight="balanced", + ) + model.fit(x_train, y_train) + + 
accuracy: float | None = None + if x_test is not None and y_test is not None: + predicted = model.predict(x_test) + accuracy = float(accuracy_score(y_test, predicted)) + + model_path.parent.mkdir(parents=True, exist_ok=True) + joblib.dump({"model": model, "label_encoder": encoder}, model_path) + + return TrainingResult( + model_path=str(model_path), + trained_samples=int(features.shape[0]), + classes=[str(item) for item in encoder.classes_.tolist()], + accuracy=accuracy, + ) + + +def predict_sample( + audio_file: str, + raw_dir: Path, + model_path: Path, +) -> PredictionResult: + if not model_path.exists(): + raise FileNotFoundError(f"Model file was not found: {model_path}") + + model_package = joblib.load(model_path) + model = model_package.get("model") + encoder = model_package.get("label_encoder") + + audio_path = raw_dir / audio_file + if not audio_path.exists(): + raise FileNotFoundError(f"Audio file was not found: {audio_path}") + + audio, sample_rate = read_audio(audio_path) + feature_vector = extract_feature_vector(audio, sample_rate).reshape(1, -1) + + if hasattr(model, "predict_proba"): + probabilities = model.predict_proba(feature_vector)[0] + best_probability_index = int(np.argmax(probabilities)) + encoded_label = int(model.classes_[best_probability_index]) + confidence = float(probabilities[best_probability_index]) + else: + encoded_label = int(model.predict(feature_vector)[0]) + confidence = 1.0 + + predicted_label = str(encoder.inverse_transform([encoded_label])[0]) + return PredictionResult(predicted_label=predicted_label, confidence=confidence) diff --git a/web/app.js b/web/app.js new file mode 100644 index 0000000..5b3b665 --- /dev/null +++ b/web/app.js @@ -0,0 +1,689 @@ +const dom = { + statusBanner: document.getElementById("statusBanner"), + deviceSelect: document.getElementById("deviceSelect"), + durationInput: document.getElementById("durationInput"), + tagInput: document.getElementById("tagInput"), + captureButton: 
document.getElementById("captureButton"), + trainButton: document.getElementById("trainButton"), + refreshButton: document.getElementById("refreshButton"), + refreshPendingButton: document.getElementById("refreshPendingButton"), + trainingOutput: document.getElementById("trainingOutput"), + samplesList: document.getElementById("samplesList"), + pendingList: document.getElementById("pendingList"), + overnightHours: document.getElementById("overnightHours"), + overnightClipSeconds: document.getElementById("overnightClipSeconds"), + overnightIntervalSeconds: document.getElementById("overnightIntervalSeconds"), + overnightAutoWatch: document.getElementById("overnightAutoWatch"), + overnightStartButton: document.getElementById("overnightStartButton"), + overnightStopButton: document.getElementById("overnightStopButton"), + overnightOutput: document.getElementById("overnightOutput"), +}; + +const LABELS = ["snore", "not_snore", "unclear"]; + +function ensureArray(value) { + return Array.isArray(value) ? 
value : []; +} + +function numberOr(defaultValue, rawValue) { + const parsed = Number(rawValue); + if (!Number.isFinite(parsed)) { + return defaultValue; + } + return parsed; +} + +function setStatus(message, kind = "ok") { + if (!dom.statusBanner) { + return; + } + + dom.statusBanner.textContent = message; + dom.statusBanner.classList.toggle("status-error", kind === "error"); + dom.statusBanner.classList.toggle("status-ok", kind === "ok"); +} + +async function api(path, options = {}) { + const init = { + ...options, + headers: { + ...(options.headers || {}), + }, + }; + + if (init.body !== undefined && !init.headers["Content-Type"]) { + init.headers["Content-Type"] = "application/json"; + } + + const response = await fetch(path, init); + + let payload = null; + const contentType = response.headers.get("content-type") || ""; + if (contentType.includes("application/json")) { + try { + payload = await response.json(); + } catch { + payload = null; + } + } + + if (!response.ok) { + const detail = payload && payload.detail ? 
payload.detail : `HTTP ${response.status}`; + throw new Error(detail); + } + + return payload; +} + +function formatDate(isoTimestamp) { + if (!isoTimestamp) { + return "-"; + } + const date = new Date(isoTimestamp); + return date.toLocaleString(); +} + +function formatProposal(sample) { + if (!sample || !sample.proposed_label) { + return "-"; + } + if (typeof sample.proposed_confidence !== "number") { + return String(sample.proposed_label); + } + + const percent = (sample.proposed_confidence * 100).toFixed(1); + return `${sample.proposed_label} (${percent}%)`; +} + +function formatReviewState(sample) { + if (!sample) { + return "not reviewed"; + } + if (sample.review_state === "approved") { + if (sample.proposed_label && sample.label && sample.proposed_label !== sample.label) { + return "corrected + approved"; + } + return "approved for training"; + } + if (sample.review_state === "pending") { + return "pending thumbs"; + } + if (sample.review_state === "rejected") { + return "rejected"; + } + if (sample.review_state === "manual") { + return sample.training_approved ? 
"manual approved" : "manual not used"; + } + return "not reviewed"; +} + +function buildLabelEditor(sample) { + const wrapper = document.createElement("div"); + wrapper.className = "actions"; + + const select = document.createElement("select"); + select.dataset.action = "label-select"; + select.dataset.sampleId = sample.sample_id; + + for (const label of LABELS) { + const option = document.createElement("option"); + option.value = label; + option.textContent = label; + option.selected = (sample.label || "unclear") === label; + select.appendChild(option); + } + + const saveButton = document.createElement("button"); + saveButton.type = "button"; + saveButton.textContent = "Save"; + saveButton.className = "btn"; + saveButton.dataset.action = "save-label"; + saveButton.dataset.sampleId = sample.sample_id; + + wrapper.append(select, saveButton); + return wrapper; +} + +function buildWatchButton(sample) { + const wrapper = document.createElement("div"); + wrapper.className = "actions"; + + const watchButton = document.createElement("button"); + watchButton.type = "button"; + watchButton.textContent = "Run Watch"; + watchButton.className = "btn"; + watchButton.dataset.action = "propose-watch"; + watchButton.dataset.sampleId = sample.sample_id; + + const output = document.createElement("span"); + output.textContent = formatReviewState(sample); + output.style.fontFamily = "JetBrains Mono, monospace"; + output.style.fontSize = "0.75rem"; + + wrapper.append(watchButton, output); + return wrapper; +} + +function renderSamples(samples) { + if (!dom.samplesList) { + return; + } + + dom.samplesList.innerHTML = ""; + + if (!samples.length) { + const empty = document.createElement("p"); + empty.className = "empty-state"; + empty.textContent = "No samples yet. 
Record your first sample above."; + dom.samplesList.appendChild(empty); + return; + } + + for (const sample of samples) { + const card = document.createElement("article"); + card.className = "record-card sample-card"; + + const header = document.createElement("div"); + header.className = "record-header"; + + const sampleId = document.createElement("code"); + sampleId.textContent = String(sample.sample_id || "").slice(0, 8); + + const recorded = document.createElement("span"); + recorded.className = "meta-text"; + recorded.textContent = formatDate(sample.created_at); + + const tag = document.createElement("span"); + tag.className = "chip"; + tag.textContent = sample.tag || "untagged"; + + header.append(sampleId, recorded, tag); + + const media = document.createElement("div"); + media.className = "record-media"; + + const listenCell = document.createElement("div"); + const audio = document.createElement("audio"); + audio.controls = true; + audio.preload = "none"; + audio.src = sample.audio_url; + listenCell.appendChild(audio); + + const spectrogramCell = document.createElement("div"); + const image = document.createElement("img"); + image.className = "spectrogram"; + image.src = sample.spectrogram_url; + image.alt = `Spectrogram for ${sample.sample_id}`; + spectrogramCell.appendChild(image); + + media.append(listenCell, spectrogramCell); + + const stats = document.createElement("div"); + stats.className = "record-meta-row"; + + const proposal = document.createElement("span"); + proposal.className = "meta-text"; + proposal.textContent = `Proposal: ${formatProposal(sample)}`; + + const state = document.createElement("span"); + state.className = "meta-text"; + state.textContent = `State: ${formatReviewState(sample)}`; + + stats.append(proposal, state); + + const actions = document.createElement("div"); + actions.className = "record-actions"; + + const labelEditor = buildLabelEditor(sample); + const watchEditor = buildWatchButton(sample); + actions.append(labelEditor, 
watchEditor); + + card.append(header, media, stats, actions); + dom.samplesList.appendChild(card); + } +} + +function buildPendingReviewCard(sample) { + const card = document.createElement("article"); + card.className = "record-card pending-card"; + + const header = document.createElement("div"); + header.className = "record-header"; + + const code = document.createElement("code"); + code.textContent = String(sample.sample_id || "").slice(0, 8); + + const recorded = document.createElement("span"); + recorded.className = "meta-text"; + recorded.textContent = formatDate(sample.created_at); + + const proposal = document.createElement("span"); + proposal.className = "chip proposal-chip"; + proposal.textContent = formatProposal(sample); + + header.append(code, recorded, proposal); + + const media = document.createElement("div"); + media.className = "record-media"; + + const listenCell = document.createElement("div"); + const audio = document.createElement("audio"); + audio.controls = true; + audio.preload = "none"; + audio.src = sample.audio_url; + listenCell.appendChild(audio); + + const spectrogramCell = document.createElement("div"); + const image = document.createElement("img"); + image.className = "spectrogram"; + image.src = sample.spectrogram_url; + image.alt = `Spectrogram for ${sample.sample_id}`; + spectrogramCell.appendChild(image); + + media.append(listenCell, spectrogramCell); + + const reviewCell = document.createElement("div"); + reviewCell.className = "record-actions"; + const actionGroup = document.createElement("div"); + actionGroup.className = "actions"; + + const thumbsUp = document.createElement("button"); + thumbsUp.type = "button"; + thumbsUp.className = "btn success"; + thumbsUp.textContent = "Thumbs Up"; + thumbsUp.dataset.action = "pending-up"; + thumbsUp.dataset.sampleId = sample.sample_id; + + const thumbsDown = document.createElement("button"); + thumbsDown.type = "button"; + thumbsDown.className = "btn danger"; + thumbsDown.textContent = 
"Thumbs Down (Invert)"; + thumbsDown.dataset.action = "pending-down"; + thumbsDown.dataset.sampleId = sample.sample_id; + + actionGroup.append(thumbsUp, thumbsDown); + reviewCell.appendChild(actionGroup); + + card.append(header, media, reviewCell); + return card; +} + +function renderPendingReviews(samples) { + if (!dom.pendingList) { + return; + } + + dom.pendingList.innerHTML = ""; + + if (!samples.length) { + const empty = document.createElement("p"); + empty.className = "empty-state"; + empty.textContent = "No pending watch proposals."; + dom.pendingList.appendChild(empty); + return; + } + + for (const sample of samples) { + dom.pendingList.appendChild(buildPendingReviewCard(sample)); + } +} + +function renderOvernightStatus(status) { + if (!dom.overnightOutput || !dom.overnightStartButton || !dom.overnightStopButton) { + return; + } + + dom.overnightStartButton.disabled = Boolean(status && status.running); + dom.overnightStopButton.disabled = !Boolean(status && status.running); + + if (!status || !status.session_id) { + dom.overnightOutput.textContent = "No active overnight session."; + return; + } + + dom.overnightOutput.textContent = [ + `running: ${status.running}`, + `session: ${String(status.session_id).slice(0, 8)}`, + `started: ${formatDate(status.started_at)}`, + `planned_end: ${formatDate(status.planned_end_at)}`, + `next_capture: ${formatDate(status.next_capture_at)}`, + `captured_samples: ${status.captured_samples}`, + `pending_reviews: ${status.pending_reviews}`, + `auto_watch: ${status.auto_watch}`, + `last_error: ${status.last_error || "-"}`, + ].join("\n"); +} + +async function loadDevices() { + if (!dom.deviceSelect) { + return; + } + + try { + const devices = ensureArray(await api("/api/audio/devices")); + dom.deviceSelect.innerHTML = ""; + + if (!devices.length) { + const fallback = document.createElement("option"); + fallback.value = ""; + fallback.textContent = "No input devices found"; + dom.deviceSelect.appendChild(fallback); + return; + 
} + + for (const device of devices) { + const option = document.createElement("option"); + option.value = String(device.index); + option.textContent = `#${device.index} ${device.name}`; + dom.deviceSelect.appendChild(option); + } + } catch (error) { + dom.deviceSelect.innerHTML = ""; + const fallback = document.createElement("option"); + fallback.value = ""; + fallback.textContent = "Device lookup failed"; + dom.deviceSelect.appendChild(fallback); + setStatus(`Device load failed: ${error.message}`, "error"); + } +} + +async function loadSamples() { + try { + const samples = ensureArray(await api("/api/samples")); + renderSamples(samples); + } catch (error) { + renderSamples([]); + setStatus(`Samples load failed: ${error.message}`, "error"); + } +} + +async function loadPendingReviews() { + try { + const pending = ensureArray(await api("/api/watch/pending")); + renderPendingReviews(pending); + } catch (error) { + renderPendingReviews([]); + setStatus(`Pending queue load failed: ${error.message}`, "error"); + } +} + +async function loadOvernightStatus() { + try { + const status = await api("/api/overnight/status"); + renderOvernightStatus(status); + } catch (error) { + renderOvernightStatus(null); + setStatus(`Overnight status failed: ${error.message}`, "error"); + } +} + +async function refreshMainViews() { + await Promise.allSettled([loadSamples(), loadPendingReviews(), loadOvernightStatus()]); +} + +async function recordSample() { + if (!dom.captureButton || !dom.durationInput || !dom.deviceSelect || !dom.tagInput) { + return; + } + + const payload = { + duration_seconds: numberOr(10, dom.durationInput.value), + device_index: dom.deviceSelect.value === "" ? 
null : Number(dom.deviceSelect.value), + tag: dom.tagInput.value.trim() || null, + }; + + dom.captureButton.disabled = true; + setStatus("Recording in progress...", "ok"); + + try { + await api("/api/samples/capture", { + method: "POST", + body: JSON.stringify(payload), + }); + dom.tagInput.value = ""; + setStatus("Sample captured successfully", "ok"); + await refreshMainViews(); + } catch (error) { + setStatus(`Recording failed: ${error.message}`, "error"); + } finally { + dom.captureButton.disabled = false; + } +} + +async function saveLabel(sampleId) { + const select = document.querySelector( + `select[data-action='label-select'][data-sample-id='${sampleId}']` + ); + if (!select) { + return; + } + + try { + await api(`/api/samples/${sampleId}/label`, { + method: "POST", + body: JSON.stringify({ label: select.value }), + }); + + setStatus(`Label updated for ${sampleId.slice(0, 8)}`, "ok"); + await refreshMainViews(); + } catch (error) { + setStatus(`Label update failed: ${error.message}`, "error"); + } +} + +async function queueWatchProposal(sampleId) { + try { + await api(`/api/watch/${sampleId}/propose`, { + method: "POST", + }); + + setStatus(`Watch proposal queued for ${sampleId.slice(0, 8)}`, "ok"); + await refreshMainViews(); + } catch (error) { + setStatus(`Watch proposal failed: ${error.message}`, "error"); + } +} + +async function reviewPending(sampleId, approve) { + const endpoint = approve ? "thumbs-up" : "thumbs-down"; + + try { + await api(`/api/watch/${sampleId}/${endpoint}`, { + method: "POST", + }); + + const actionText = approve + ? 
"Thumbs up approved label" + : "Thumbs down inverted label and approved"; + setStatus(`${actionText} saved for ${sampleId.slice(0, 8)}`, "ok"); + await refreshMainViews(); + } catch (error) { + setStatus(`Review action failed: ${error.message}`, "error"); + } +} + +async function trainModel() { + if (!dom.trainButton || !dom.trainingOutput) { + return; + } + + dom.trainButton.disabled = true; + dom.trainingOutput.textContent = "Training local model..."; + + try { + const result = await api("/api/train", { + method: "POST", + body: JSON.stringify({ test_size: 0.25 }), + }); + + const accuracyText = + result && result.accuracy === null + ? "n/a (dataset too small for holdout)" + : Number(result.accuracy).toFixed(3); + + dom.trainingOutput.textContent = [ + `model: ${result.model_path}`, + `trained_samples: ${result.trained_samples}`, + `classes: ${result.classes.join(", ")}`, + `accuracy: ${accuracyText}`, + ].join("\n"); + + setStatus("Training completed", "ok"); + } catch (error) { + dom.trainingOutput.textContent = `Training failed: ${error.message}`; + setStatus(`Training failed: ${error.message}`, "error"); + } finally { + dom.trainButton.disabled = false; + } +} + +async function startOvernight() { + if ( + !dom.overnightStartButton || + !dom.overnightHours || + !dom.overnightClipSeconds || + !dom.overnightIntervalSeconds || + !dom.deviceSelect || + !dom.overnightAutoWatch + ) { + return; + } + + dom.overnightStartButton.disabled = true; + + const payload = { + duration_hours: numberOr(8, dom.overnightHours.value), + clip_duration_seconds: numberOr(20, dom.overnightClipSeconds.value), + interval_seconds: numberOr(30, dom.overnightIntervalSeconds.value), + device_index: dom.deviceSelect.value === "" ? 
null : Number(dom.deviceSelect.value), + tag_prefix: "overnight", + auto_watch: Boolean(dom.overnightAutoWatch.checked), + }; + + try { + await api("/api/overnight/start", { + method: "POST", + body: JSON.stringify(payload), + }); + + setStatus("Overnight gathering started", "ok"); + await refreshMainViews(); + } catch (error) { + setStatus(`Overnight start failed: ${error.message}`, "error"); + } finally { + dom.overnightStartButton.disabled = false; + } +} + +async function stopOvernight() { + if (!dom.overnightStopButton) { + return; + } + + dom.overnightStopButton.disabled = true; + + try { + await api("/api/overnight/stop", { + method: "POST", + }); + + setStatus("Overnight gathering stopped", "ok"); + await refreshMainViews(); + } catch (error) { + setStatus(`Overnight stop failed: ${error.message}`, "error"); + } finally { + dom.overnightStopButton.disabled = false; + } +} + +function wireEventHandlers() { + if (dom.samplesList) { + dom.samplesList.addEventListener("click", async (event) => { + const button = event.target.closest("button[data-action]"); + if (!button) { + return; + } + + const action = button.dataset.action; + const sampleId = button.dataset.sampleId; + if (!sampleId) { + return; + } + + if (action === "save-label") { + await saveLabel(sampleId); + } + if (action === "propose-watch") { + await queueWatchProposal(sampleId); + } + }); + } + + if (dom.pendingList) { + dom.pendingList.addEventListener("click", async (event) => { + const button = event.target.closest("button[data-action]"); + if (!button) { + return; + } + + const action = button.dataset.action; + const sampleId = button.dataset.sampleId; + if (!sampleId) { + return; + } + + if (action === "pending-up") { + await reviewPending(sampleId, true); + } + if (action === "pending-down") { + await reviewPending(sampleId, false); + } + }); + } + + if (dom.captureButton) { + dom.captureButton.addEventListener("click", recordSample); + } + if (dom.trainButton) { + 
dom.trainButton.addEventListener("click", trainModel); + } + if (dom.refreshButton) { + dom.refreshButton.addEventListener("click", loadSamples); + } + if (dom.refreshPendingButton) { + dom.refreshPendingButton.addEventListener("click", loadPendingReviews); + } + if (dom.overnightStartButton) { + dom.overnightStartButton.addEventListener("click", startOvernight); + } + if (dom.overnightStopButton) { + dom.overnightStopButton.addEventListener("click", stopOvernight); + } +} + +async function boot() { + try { + const health = await api("/api/health"); + setStatus(`Server online at ${health.sample_rate}Hz / ${health.channels}ch`, "ok"); + } catch (error) { + setStatus(`Startup error: ${error.message}`, "error"); + } + + await Promise.allSettled([ + loadDevices(), + loadSamples(), + loadPendingReviews(), + loadOvernightStatus(), + ]); + + wireEventHandlers(); + + // Keep overnight status current without requiring manual refreshes. + window.setInterval(() => { + loadOvernightStatus().catch(() => { + // Status errors are handled inside loadOvernightStatus. + }); + }, 15000); +} + +boot(); diff --git a/web/index.html b/web/index.html new file mode 100644 index 0000000..d52edfb --- /dev/null +++ b/web/index.html @@ -0,0 +1,107 @@ + + + + + + SnoreStopper Control Room + + + + + + + +
+
+

Self-hosted sleep signal lab

+

SnoreStopper

+

+ Capture ambient sleep audio, label snoring events, and train a local model from your own + nights. +

+
+ +
+

Capture Sample

+
+ + +
+
+
+ + +
+
+ + +
+
+ +
+ +
+

Overnight Gatherer

+
+
+ + +
+
+ + +
+
+ + +
+
+ +
+ + +
+
No active overnight session.
+
+ +
+

Training

+

+ Label samples as snore, not_snore, or unclear, then + train locally. +

+ +
No training run yet.
+
+ +
+

Status

+

Loading...

+
+ +
+
+

Snore Watch Review Queue

+ +
+
+
+ +
+
+

Samples

+ +
+
+
+
+ + + + diff --git a/web/styles.css b/web/styles.css new file mode 100644 index 0000000..5eaa57d --- /dev/null +++ b/web/styles.css @@ -0,0 +1,383 @@
/* ---------- Design tokens ---------- */
:root {
  --bg-a: #0b2238;
  --bg-b: #153a2e;
  --panel: rgba(255, 255, 255, 0.92);
  --ink: #112333;
  --ink-soft: #35516b;
  --accent: #0d8d72;
  --accent-strong: #096f5a;
  --outline: rgba(17, 35, 51, 0.16);
  --danger: #b14d40;
  --success: #198754;
  --success-strong: #146c43;
  --danger-strong: #9f2f24;
}

* {
  box-sizing: border-box;
}

/* ---------- Page background ---------- */
body {
  margin: 0;
  font-family: "Chakra Petch", "Segoe UI", sans-serif;
  color: var(--ink);
  min-height: 100vh;
  background: radial-gradient(circle at 20% 10%, #245f7e 0%, transparent 45%),
    radial-gradient(circle at 85% 20%, #2c6f47 0%, transparent 50%),
    linear-gradient(145deg, var(--bg-a), var(--bg-b));
}

/* Decorative grid overlay; pointer-events: none keeps it from eating clicks. */
.ambient-backdrop {
  position: fixed;
  inset: 0;
  background-image: linear-gradient(to right, rgba(255, 255, 255, 0.04) 1px, transparent 1px),
    linear-gradient(to bottom, rgba(255, 255, 255, 0.04) 1px, transparent 1px);
  background-size: 22px 22px;
  opacity: 0.42;
  pointer-events: none;
}

/* ---------- 12-column page grid ---------- */
.layout {
  position: relative;
  z-index: 1;
  width: min(1200px, 92vw);
  margin: 2rem auto 3rem;
  display: grid;
  gap: 1rem;
  grid-template-columns: repeat(12, minmax(0, 1fr));
  animation: fade-in 500ms ease;
}

.hero,
.panel {
  background: var(--panel);
  border: 1px solid var(--outline);
  border-radius: 16px;
  box-shadow: 0 18px 30px rgba(9, 20, 28, 0.16);
}

.hero {
  grid-column: 1 / -1;
  padding: 1.5rem;
}

.hero h1 {
  margin: 0.3rem 0 0.5rem;
  font-size: clamp(2rem, 4vw, 2.8rem);
  letter-spacing: 0.03em;
}

.hero p {
  margin: 0;
  color: var(--ink-soft);
}

.eyebrow {
  font-family: "JetBrains Mono", monospace;
  font-size: 0.9rem;
  text-transform: uppercase;
  letter-spacing: 0.1em;
  color: var(--accent-strong);
}

.panel {
  padding: 1.2rem;
}

/* Column spans of the individual panels inside the 12-column grid. */
.controls {
  grid-column: span 6;
}

.overnight {
  grid-column: span 6;
}

.training {
  grid-column: span 8;
}

.status {
  grid-column: span 4;
}

.samples {
  grid-column: 1 / -1;
}

.pending {
  grid-column: 1 / -1;
}

h2 {
  margin: 0 0 0.8rem;
  font-size: 1.2rem;
}

/* ---------- Form layout ---------- */
.row {
  display: flex;
  flex-direction: column;
  gap: 0.4rem;
  margin-bottom: 0.85rem;
}

.split {
  display: grid;
  grid-template-columns: 1fr 1fr;
  gap: 0.8rem;
}

.overnight-grid {
  grid-template-columns: repeat(3, 1fr);
}

.toggle-row {
  display: flex;
  align-items: center;
  gap: 0.5rem;
  margin: 0.45rem 0 0.9rem;
  font-size: 0.92rem;
  color: var(--ink-soft);
}

/* Undo the full-width rule for inputs so the checkbox keeps its own size. */
.toggle-row input {
  width: auto;
}

label {
  font-family: "JetBrains Mono", monospace;
  font-size: 0.8rem;
  letter-spacing: 0.02em;
}

input,
select,
button {
  font: inherit;
}

input,
select {
  width: 100%;
  border: 1px solid var(--outline);
  border-radius: 10px;
  padding: 0.55rem 0.7rem;
  background: #ffffff;
}

/* ---------- Buttons ---------- */
.btn {
  border: 1px solid var(--outline);
  border-radius: 10px;
  padding: 0.55rem 0.8rem;
  background: #ecf3f7;
  color: var(--ink);
  cursor: pointer;
  transition: transform 120ms ease, background 120ms ease;
}

.btn:hover {
  transform: translateY(-1px);
  background: #dfeaf1;
}

.btn.primary {
  background: var(--accent);
  border-color: var(--accent);
  color: #ffffff;
}

.btn.primary:hover {
  background: var(--accent-strong);
}

.btn.success {
  background: var(--success);
  border-color: var(--success);
  color: #ffffff;
}

.btn.success:hover {
  background: var(--success-strong);
}

.btn.danger {
  background: #c0392b;
  border-color: #c0392b;
  color: #ffffff;
}

.btn.danger:hover {
  background: var(--danger-strong);
}

.btn.ghost {
  background: transparent;
}

/* The page script disables buttons while a request is in flight; show that
   state instead of leaving them looking clickable. Declared after the hover
   rule (same specificity) so the hover lift is cancelled. */
.btn:disabled {
  opacity: 0.6;
  cursor: not-allowed;
  transform: none;
}

/* ---------- Output / status text ---------- */
.output {
  margin: 0.8rem 0 0;
  padding: 0.8rem;
  border-radius: 10px;
  border: 1px solid var(--outline);
  background: #f4f8fb;
  color: #1b384f;
  min-height: 72px;
  white-space: pre-wrap;
  overflow-wrap: anywhere;
  font-family: "JetBrains Mono", monospace;
  font-size: 0.82rem;
}

/* ---------- Sample and review lists ---------- */
.samples-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
  margin-bottom: 0.8rem;
}

.sample-list,
.pending-list {
  display: grid;
  gap: 0.85rem;
}

.record-card {
  border: 1px solid var(--outline);
  border-radius: 12px;
  background: #ffffff;
  padding: 0.9rem;
  display: grid;
  gap: 0.8rem;
}

.record-header {
  display: flex;
  align-items: center;
  flex-wrap: wrap;
  gap: 0.55rem;
}

.record-header code {
  font-family: "JetBrains Mono", monospace;
  font-size: 0.78rem;
  background: #eff5f9;
  border-radius: 8px;
  padding: 0.2rem 0.45rem;
}

.meta-text {
  color: var(--ink-soft);
  font-size: 0.85rem;
}

.chip {
  border: 1px solid var(--outline);
  border-radius: 999px;
  background: #f3f8fb;
  color: #23475f;
  padding: 0.2rem 0.55rem;
  font-size: 0.78rem;
  line-height: 1.2;
}

.proposal-chip {
  font-family: "JetBrains Mono", monospace;
}

.record-media {
  display: grid;
  grid-template-columns: minmax(180px, 230px) minmax(180px, 240px);
  gap: 0.8rem;
  align-items: start;
}

.record-meta-row {
  display: flex;
  flex-wrap: wrap;
  gap: 0.7rem;
}

.record-actions {
  display: flex;
  flex-wrap: wrap;
  gap: 0.6rem;
  align-items: center;
}

.record-actions .actions {
  border: 1px solid var(--outline);
  border-radius: 10px;
  background: #fbfdff;
  padding: 0.45rem;
}

img.spectrogram {
  width: min(220px, 100%);
  height: auto;
  border-radius: 8px;
  border: 1px solid var(--outline);
}

audio {
  width: min(230px, 100%);
}

.actions {
  display: flex;
  align-items: center;
  flex-wrap: wrap;
  gap: 0.5rem;
}

.empty-state {
  margin: 0;
  border: 1px dashed var(--outline);
  border-radius: 10px;
  background: #f8fbfd;
  color: var(--ink-soft);
  padding: 0.9rem;
}

.status-error {
  color: var(--danger);
}

.status-ok {
  color: var(--accent-strong);
}

/* ---------- Motion ---------- */
@keyframes fade-in {
  from {
    opacity: 0;
    transform: translateY(8px);
  }
  to {
    opacity: 1;
    transform: translateY(0);
  }
}

/* ---------- Narrow screens: stack every panel and media column ---------- */
@media (max-width: 980px) {
  .controls,
  .overnight,
  .training,
  .status {
    grid-column: 1 / -1;
  }

  .split {
    grid-template-columns: 1fr;
  }

  .overnight-grid {
    grid-template-columns: 1fr;
  }

  .record-media {
    grid-template-columns: 1fr;
  }

  img.spectrogram,
  audio {
    width: 100%;
  }
}

/* ---------- Honour the user's reduced-motion preference ---------- */
@media (prefers-reduced-motion: reduce) {
  .layout {
    animation: none;
  }

  .btn {
    transition: none;
  }

  .btn:hover {
    transform: none;
  }
}