feat: add storage and training modules for snore detection

- Implemented `storage.py` for managing metadata storage, including sample addition, retrieval, and review state management.
- Created `training.py` for training a local model using Random Forest, including functions for training and predicting samples.
- Developed a web interface in `app.js` for capturing audio samples, managing labels, and training the model.
- Added HTML structure in `index.html` for the SnoreStopper control room with sections for sample capture, overnight gathering, training, and status display.
- Styled the application with `styles.css` to enhance user experience and interface aesthetics.
This commit is contained in:
2026-03-12 13:35:17 -04:00
commit 28012e70e0
21 changed files with 2680 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
"""SnoreStopper package."""
__all__ = ["__version__"]
__version__ = "0.1.0"

68
src/snorestopper/audio.py Normal file
View File

@@ -0,0 +1,68 @@
from __future__ import annotations
from typing import Any
import numpy as np
try:
import sounddevice as sd
except Exception:
sd = None
def _require_sounddevice() -> None:
if sd is None:
raise RuntimeError(
"sounddevice is unavailable. Install requirements and ensure PortAudio is available."
)
def list_input_devices() -> list[dict[str, Any]]:
_require_sounddevice()
devices = sd.query_devices()
input_devices: list[dict[str, Any]] = []
for index, device in enumerate(devices):
max_input_channels = int(device.get("max_input_channels", 0))
if max_input_channels <= 0:
continue
input_devices.append(
{
"index": index,
"name": str(device.get("name", f"Input Device {index}")),
"max_input_channels": max_input_channels,
"default_samplerate": float(device.get("default_samplerate", 0.0)),
}
)
return input_devices
def capture_audio(
duration_seconds: int,
sample_rate: int,
channels: int,
device_index: int | None = None,
) -> np.ndarray:
_require_sounddevice()
if duration_seconds <= 0:
raise RuntimeError("duration_seconds must be positive")
frame_count = int(duration_seconds * sample_rate)
recording = sd.rec(
frame_count,
samplerate=sample_rate,
channels=channels,
dtype="float32",
device=device_index,
)
sd.wait()
audio = np.asarray(recording, dtype=np.float32)
if audio.ndim == 2:
# Collapse to mono for simpler feature extraction and training.
audio = np.mean(audio, axis=1)
return audio

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import os
@dataclass(frozen=True)
class Settings:
project_root: Path
data_dir: Path
raw_dir: Path
spectrogram_dir: Path
models_dir: Path
metadata_dir: Path
web_dir: Path
sample_rate: int
channels: int
min_duration_seconds: int
max_duration_seconds: int
model_file_name: str
def load_settings() -> Settings:
project_root = Path(
os.getenv("SNORESTOPPER_ROOT", Path(__file__).resolve().parents[2])
)
data_dir = project_root / "data"
return Settings(
project_root=project_root,
data_dir=data_dir,
raw_dir=data_dir / "raw",
spectrogram_dir=data_dir / "spectrograms",
models_dir=data_dir / "models",
metadata_dir=data_dir / "meta",
web_dir=project_root / "web",
sample_rate=int(os.getenv("SNORESTOPPER_SAMPLE_RATE", "16000")),
channels=int(os.getenv("SNORESTOPPER_CHANNELS", "1")),
min_duration_seconds=int(os.getenv("SNORESTOPPER_MIN_DURATION", "2")),
max_duration_seconds=int(os.getenv("SNORESTOPPER_MAX_DURATION", "90")),
model_file_name=os.getenv("SNORESTOPPER_MODEL_FILE", "snore_classifier.joblib"),
)

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
from scipy import signal
def save_audio_as_flac(audio: np.ndarray, sample_rate: int, output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(file=str(output_path), data=audio, samplerate=sample_rate, format="FLAC")
def read_audio(audio_path: Path) -> tuple[np.ndarray, int]:
audio, sample_rate = sf.read(str(audio_path), always_2d=False)
waveform = np.asarray(audio, dtype=np.float32)
if waveform.ndim == 2:
waveform = np.mean(waveform, axis=1)
return waveform, int(sample_rate)
def _compute_spectrogram(audio: np.ndarray, sample_rate: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
n_samples = int(audio.shape[0])
if n_samples <= 1:
nperseg = 1
noverlap = 0
else:
nperseg = min(1024, n_samples)
noverlap = min(int(nperseg * 0.5), nperseg - 1)
return signal.spectrogram(
audio,
fs=sample_rate,
nperseg=nperseg,
noverlap=noverlap,
scaling="spectrum",
mode="magnitude",
)
def save_spectrogram(audio: np.ndarray, sample_rate: int, output_path: Path) -> None:
frequencies, times, spectrum = _compute_spectrogram(audio, sample_rate)
power_db = 10.0 * np.log10(spectrum + 1e-12)
output_path.parent.mkdir(parents=True, exist_ok=True)
fig, axis = plt.subplots(figsize=(8, 3))
mesh = axis.pcolormesh(times, frequencies, power_db, shading="gouraud", cmap="magma")
axis.set_title("SnoreStopper Spectrogram")
axis.set_ylabel("Frequency (Hz)")
axis.set_xlabel("Time (s)")
fig.colorbar(mesh, ax=axis, label="Power (dB)")
fig.tight_layout()
fig.savefig(output_path, dpi=120)
plt.close(fig)
def _band_energy(power: np.ndarray, frequencies: np.ndarray, low: float, high: float) -> float:
mask = (frequencies >= low) & (frequencies < high)
if not np.any(mask):
return 0.0
return float(np.mean(power[mask, :]))
def extract_feature_vector(audio: np.ndarray, sample_rate: int) -> np.ndarray:
frequencies, _, spectrum = _compute_spectrogram(audio, sample_rate)
power = np.log1p(spectrum)
mean_power = float(np.mean(power))
std_power = float(np.std(power))
max_power = float(np.max(power))
low_band = _band_energy(power, frequencies, 20.0, 250.0)
mid_band = _band_energy(power, frequencies, 250.0, 1000.0)
high_band = _band_energy(power, frequencies, 1000.0, 4000.0)
spectral_weights = np.mean(power, axis=1) + 1e-9
spectral_centroid = float(np.average(frequencies, weights=spectral_weights))
return np.asarray(
[
mean_power,
std_power,
max_power,
low_band,
mid_band,
high_band,
spectral_centroid,
],
dtype=np.float32,
)

555
src/snorestopper/main.py Normal file
View File

@@ -0,0 +1,555 @@
from __future__ import annotations
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from uuid import uuid4
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from .audio import capture_audio, list_input_devices
from .config import Settings, load_settings
from .features import save_audio_as_flac, save_spectrogram
from .overnight import OvernightCaptureController
from .schemas import (
ApiMessage,
CaptureRequest,
DetectRequest,
DetectResponse,
DeviceInfo,
LabelRequest,
OvernightStartRequest,
OvernightStatusResponse,
SampleRecord,
TrainRequest,
TrainResponse,
)
from .storage import MetadataStore, ensure_directories
from .training import predict_sample, train_local_model
settings = load_settings()
metadata_file = ensure_directories(settings)
store = MetadataStore(metadata_file)
BINARY_LABELS = {"snore", "not_snore"}
VALID_SAMPLE_LABELS = {"snore", "not_snore", "unclear"}
REVIEW_STATES = {"none", "pending", "approved", "rejected", "manual"}
LABEL_SOURCES = {"none", "manual", "watch"}
app = FastAPI(
title="SnoreStopper",
version="0.1.0",
description="Self-hosted snore tracking, labeling, and model training",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.mount("/web", StaticFiles(directory=settings.web_dir), name="web")
app.mount("/data", StaticFiles(directory=settings.data_dir), name="data")
def _to_int(value: Any, default: int = 0) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
def _to_optional_float(value: Any) -> float | None:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def _invert_binary_label(label: str) -> str:
if label == "snore":
return "not_snore"
if label == "not_snore":
return "snore"
raise ValueError(f"Unsupported binary label: {label}")
def _to_sample_record(record: dict[str, Any]) -> SampleRecord:
audio_file = str(record.get("audio_file", ""))
spectrogram_file = str(record.get("spectrogram_file", ""))
label = record.get("label")
if label not in VALID_SAMPLE_LABELS:
label = None
proposed_label = record.get("proposed_label")
if proposed_label not in BINARY_LABELS:
proposed_label = None
review_state = str(record.get("review_state", "none"))
if review_state not in REVIEW_STATES:
review_state = "none"
label_source = str(record.get("label_source", "none"))
if label_source not in LABEL_SOURCES:
label_source = "none"
return SampleRecord(
sample_id=str(record.get("sample_id", "")),
created_at=str(record.get("created_at", "")),
duration_seconds=_to_int(record.get("duration_seconds"), default=0),
sample_rate=_to_int(record.get("sample_rate"), default=settings.sample_rate),
device_index=record.get("device_index"),
tag=record.get("tag"),
label=label,
audio_file=audio_file,
spectrogram_file=spectrogram_file,
audio_url=f"/data/raw/{audio_file}",
spectrogram_url=f"/data/spectrograms/{spectrogram_file}",
proposed_label=proposed_label,
proposed_confidence=_to_optional_float(record.get("proposed_confidence")),
review_state=review_state,
training_approved=bool(record.get("training_approved", False)),
label_source=label_source,
)
def _resolve_model_path(model_path_value: str | None) -> Path:
if not model_path_value:
return settings.models_dir / settings.model_file_name
candidate = Path(model_path_value).expanduser()
if not candidate.is_absolute():
candidate = settings.project_root / candidate
return candidate
def _validate_capture_duration_seconds(duration_seconds: int, app_settings: Settings) -> None:
if duration_seconds < app_settings.min_duration_seconds:
raise HTTPException(
status_code=400,
detail=(
f"duration_seconds must be at least {app_settings.min_duration_seconds} "
"seconds"
),
)
if duration_seconds > app_settings.max_duration_seconds:
raise HTTPException(
status_code=400,
detail=(
f"duration_seconds must be {app_settings.max_duration_seconds} "
"seconds or less"
),
)
def _capture_and_store_sample(
*,
duration_seconds: int,
device_index: int | None,
tag: str | None,
watch_after_capture: bool,
) -> dict[str, Any]:
_validate_capture_duration_seconds(duration_seconds, settings)
try:
waveform = capture_audio(
duration_seconds=duration_seconds,
sample_rate=settings.sample_rate,
channels=settings.channels,
device_index=device_index,
)
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Recording failed: {exc}") from exc
sample_id = uuid4().hex
audio_file = f"{sample_id}.flac"
spectrogram_file = f"{sample_id}.png"
audio_path = settings.raw_dir / audio_file
spectrogram_path = settings.spectrogram_dir / spectrogram_file
try:
save_audio_as_flac(waveform, settings.sample_rate, audio_path)
save_spectrogram(waveform, settings.sample_rate, spectrogram_path)
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Saving sample failed: {exc}") from exc
record: dict[str, Any] = {
"sample_id": sample_id,
"created_at": datetime.now(tz=timezone.utc).isoformat(),
"duration_seconds": duration_seconds,
"sample_rate": settings.sample_rate,
"device_index": device_index,
"tag": tag,
"label": None,
"audio_file": audio_file,
"spectrogram_file": spectrogram_file,
"proposed_label": None,
"proposed_confidence": None,
"review_state": "none",
"training_approved": False,
"label_source": "none",
}
if watch_after_capture:
model_path = settings.models_dir / settings.model_file_name
if model_path.exists():
try:
prediction = predict_sample(
audio_file=audio_file,
raw_dir=settings.raw_dir,
model_path=model_path,
)
if prediction.predicted_label in BINARY_LABELS:
record.update(
{
"proposed_label": prediction.predicted_label,
"proposed_confidence": prediction.confidence,
"review_state": "pending",
"training_approved": False,
"label_source": "watch",
}
)
except Exception:
# Capture should still be retained even when watch inference fails.
pass
store.add_sample(record)
return record
def _background_capture_job(payload: dict[str, Any]) -> dict[str, Any]:
return _capture_and_store_sample(
duration_seconds=_to_int(payload.get("duration_seconds"), default=10),
device_index=payload.get("device_index"),
tag=str(payload.get("tag")) if payload.get("tag") is not None else None,
watch_after_capture=bool(payload.get("watch_after_capture", False)),
)
overnight_controller = OvernightCaptureController(
run_capture_job=_background_capture_job,
pending_review_count=store.count_pending_review_samples,
)
@app.get("/", include_in_schema=False)
def root_redirect() -> RedirectResponse:
return RedirectResponse(url="/web/index.html")
@app.get("/api/health")
def health() -> dict[str, Any]:
return {
"status": "ok",
"project_root": str(settings.project_root),
"sample_rate": settings.sample_rate,
"channels": settings.channels,
}
@app.get("/api/audio/devices", response_model=list[DeviceInfo])
def get_audio_devices() -> list[DeviceInfo]:
try:
devices = list_input_devices()
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Unable to query input devices: {exc}") from exc
return [DeviceInfo(**item) for item in devices]
@app.get("/api/samples", response_model=list[SampleRecord])
def list_samples() -> list[SampleRecord]:
return [_to_sample_record(item) for item in store.list_samples()]
@app.post("/api/samples/capture", response_model=SampleRecord)
def capture_sample(payload: CaptureRequest) -> SampleRecord:
record = _capture_and_store_sample(
duration_seconds=payload.duration_seconds,
device_index=payload.device_index,
tag=payload.tag,
watch_after_capture=payload.watch_after_capture,
)
return _to_sample_record(record)
@app.post("/api/samples/{sample_id}/label", response_model=SampleRecord)
def label_sample(sample_id: str, payload: LabelRequest) -> SampleRecord:
training_approved = payload.label in BINARY_LABELS
try:
updated = store.update_sample_fields(
sample_id=sample_id,
updates={
"label": payload.label,
"review_state": "manual",
"training_approved": training_approved,
"label_source": "manual",
"proposed_label": None,
"proposed_confidence": None,
},
)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
return _to_sample_record(updated)
@app.post("/api/watch/{sample_id}/propose", response_model=SampleRecord)
def propose_watch_label(sample_id: str) -> SampleRecord:
sample = store.get_sample(sample_id)
if sample is None:
raise HTTPException(status_code=404, detail=f"Sample '{sample_id}' was not found")
if bool(sample.get("training_approved", False)):
raise HTTPException(
status_code=409,
detail="Sample is already approved for training. Clear or relabel it before watch proposal.",
)
model_path = settings.models_dir / settings.model_file_name
if not model_path.exists():
raise HTTPException(
status_code=404,
detail=(
"No trained model found. Train a model first, then run snore watch proposals."
),
)
try:
prediction = predict_sample(
audio_file=str(sample.get("audio_file", "")),
raw_dir=settings.raw_dir,
model_path=model_path,
)
except FileNotFoundError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Watch proposal failed: {exc}") from exc
if prediction.predicted_label not in BINARY_LABELS:
raise HTTPException(
status_code=500,
detail=(
"Model predicted an unsupported class. Retrain with labels 'snore' and 'not_snore'."
),
)
try:
updated = store.update_sample_fields(
sample_id=sample_id,
updates={
"label": None,
"proposed_label": prediction.predicted_label,
"proposed_confidence": prediction.confidence,
"review_state": "pending",
"training_approved": False,
"label_source": "watch",
},
)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
return _to_sample_record(updated)
@app.get("/api/watch/pending", response_model=list[SampleRecord])
def list_pending_watch_reviews() -> list[SampleRecord]:
pending_samples = store.list_pending_review_samples()
return [_to_sample_record(item) for item in pending_samples]
@app.post("/api/watch/{sample_id}/thumbs-up", response_model=SampleRecord)
def approve_watch_prediction(sample_id: str) -> SampleRecord:
sample = store.get_sample(sample_id)
if sample is None:
raise HTTPException(status_code=404, detail=f"Sample '{sample_id}' was not found")
proposed_label = sample.get("proposed_label")
if proposed_label not in BINARY_LABELS:
raise HTTPException(
status_code=400,
detail="Sample has no pending watch proposal to approve",
)
try:
updated = store.update_sample_fields(
sample_id=sample_id,
updates={
"label": proposed_label,
"review_state": "approved",
"training_approved": True,
"label_source": "watch",
},
)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
return _to_sample_record(updated)
@app.post("/api/watch/{sample_id}/thumbs-down", response_model=SampleRecord)
def reject_watch_prediction(sample_id: str) -> SampleRecord:
sample = store.get_sample(sample_id)
if sample is None:
raise HTTPException(status_code=404, detail=f"Sample '{sample_id}' was not found")
if sample.get("review_state") != "pending":
raise HTTPException(
status_code=400,
detail="Sample is not currently pending watch review",
)
proposed_label = sample.get("proposed_label")
if proposed_label not in BINARY_LABELS:
raise HTTPException(
status_code=400,
detail="Sample has no pending watch proposal to invert",
)
try:
corrected_label = _invert_binary_label(str(proposed_label))
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
try:
updated = store.update_sample_fields(
sample_id=sample_id,
updates={
"label": corrected_label,
"review_state": "approved",
"training_approved": True,
"label_source": "watch",
},
)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
return _to_sample_record(updated)
@app.post("/api/train", response_model=TrainResponse)
def train_model(payload: TrainRequest) -> TrainResponse:
model_path = settings.models_dir / settings.model_file_name
samples = store.list_samples()
try:
result = train_local_model(
samples=samples,
raw_dir=settings.raw_dir,
model_path=model_path,
test_size=payload.test_size,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Training failed: {exc}") from exc
return TrainResponse(
model_path=result.model_path,
trained_samples=result.trained_samples,
classes=result.classes,
accuracy=result.accuracy,
)
@app.get("/api/overnight/status", response_model=OvernightStatusResponse)
def overnight_status() -> OvernightStatusResponse:
return OvernightStatusResponse(**overnight_controller.get_status())
@app.post("/api/overnight/start", response_model=OvernightStatusResponse)
def start_overnight_capture(payload: OvernightStartRequest) -> OvernightStatusResponse:
_validate_capture_duration_seconds(payload.clip_duration_seconds, settings)
if payload.interval_seconds < payload.clip_duration_seconds:
raise HTTPException(
status_code=400,
detail="interval_seconds must be greater than or equal to clip_duration_seconds",
)
if payload.auto_watch:
model_path = settings.models_dir / settings.model_file_name
if not model_path.exists():
raise HTTPException(
status_code=400,
detail=(
"auto_watch is enabled but no trained model was found. Train a model "
"first or disable auto_watch."
),
)
try:
status = overnight_controller.start_session(
duration_hours=payload.duration_hours,
clip_duration_seconds=payload.clip_duration_seconds,
interval_seconds=payload.interval_seconds,
device_index=payload.device_index,
tag_prefix=payload.tag_prefix,
auto_watch=payload.auto_watch,
)
except RuntimeError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
return OvernightStatusResponse(**status)
@app.post("/api/overnight/stop", response_model=OvernightStatusResponse)
def stop_overnight_capture() -> OvernightStatusResponse:
return OvernightStatusResponse(**overnight_controller.stop_session())
@app.post("/api/detect", response_model=DetectResponse)
def detect_sample(payload: DetectRequest) -> DetectResponse:
sample = store.get_sample(payload.sample_id)
if sample is None:
raise HTTPException(status_code=404, detail=f"Sample '{payload.sample_id}' was not found")
model_path = _resolve_model_path(payload.model_path)
try:
prediction = predict_sample(
audio_file=str(sample["audio_file"]),
raw_dir=settings.raw_dir,
model_path=model_path,
)
except FileNotFoundError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Detection failed: {exc}") from exc
return DetectResponse(
sample_id=payload.sample_id,
predicted_label=prediction.predicted_label,
confidence=prediction.confidence,
model_path=str(model_path),
)
@app.post("/api/reset", response_model=ApiMessage)
def reset_sample_labels() -> ApiMessage:
# Keep recordings but reset labels to quickly restart training experiments.
samples = store.list_samples()
for sample in samples:
if sample.get("sample_id"):
store.update_sample_fields(
sample_id=str(sample["sample_id"]),
updates={
"label": None,
"proposed_label": None,
"proposed_confidence": None,
"review_state": "none",
"training_approved": False,
"label_source": "none",
},
)
return ApiMessage(message="All sample labels and watch proposals were reset")

View File

@@ -0,0 +1,167 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from threading import Event, Lock, Thread
from typing import Any, Callable
from uuid import uuid4
class OvernightCaptureController:
def __init__(
self,
run_capture_job: Callable[[dict[str, Any]], dict[str, Any]],
pending_review_count: Callable[[], int],
):
self._run_capture_job = run_capture_job
self._pending_review_count = pending_review_count
self._lock = Lock()
self._stop_event = Event()
self._thread: Thread | None = None
self._state: dict[str, Any] = {
"running": False,
"session_id": None,
"started_at": None,
"planned_end_at": None,
"next_capture_at": None,
"duration_hours": None,
"clip_duration_seconds": None,
"interval_seconds": None,
"auto_watch": None,
"captured_samples": 0,
"last_error": None,
}
def start_session(
self,
*,
duration_hours: int,
clip_duration_seconds: int,
interval_seconds: int,
device_index: int | None,
tag_prefix: str,
auto_watch: bool,
) -> dict[str, Any]:
with self._lock:
if self._state["running"]:
raise RuntimeError("Overnight capture is already running")
session_id = uuid4().hex
started_at = datetime.now(tz=timezone.utc)
planned_end_at = started_at + timedelta(hours=duration_hours)
self._stop_event = Event()
self._state = {
"running": True,
"session_id": session_id,
"started_at": started_at.isoformat(),
"planned_end_at": planned_end_at.isoformat(),
"next_capture_at": started_at.isoformat(),
"duration_hours": duration_hours,
"clip_duration_seconds": clip_duration_seconds,
"interval_seconds": interval_seconds,
"auto_watch": auto_watch,
"captured_samples": 0,
"last_error": None,
}
worker = Thread(
target=self._run_loop,
kwargs={
"session_id": session_id,
"planned_end_at": planned_end_at,
"clip_duration_seconds": clip_duration_seconds,
"interval_seconds": interval_seconds,
"device_index": device_index,
"tag_prefix": tag_prefix,
"auto_watch": auto_watch,
"stop_event": self._stop_event,
},
daemon=True,
name="snorestopper-overnight-capture",
)
self._thread = worker
worker.start()
return self._status_locked()
def stop_session(self) -> dict[str, Any]:
thread_to_join: Thread | None = None
with self._lock:
if not self._state["running"]:
return self._status_locked()
self._stop_event.set()
thread_to_join = self._thread
if thread_to_join is not None:
thread_to_join.join(timeout=2.0)
with self._lock:
return self._status_locked()
def get_status(self) -> dict[str, Any]:
with self._lock:
return self._status_locked()
def _status_locked(self) -> dict[str, Any]:
status = dict(self._state)
status["pending_reviews"] = int(self._pending_review_count())
return status
def _run_loop(
self,
*,
session_id: str,
planned_end_at: datetime,
clip_duration_seconds: int,
interval_seconds: int,
device_index: int | None,
tag_prefix: str,
auto_watch: bool,
stop_event: Event,
) -> None:
sample_number = 0
next_capture_at = datetime.now(tz=timezone.utc)
try:
while not stop_event.is_set():
now = datetime.now(tz=timezone.utc)
if now >= planned_end_at:
break
wait_seconds = (next_capture_at - now).total_seconds()
if wait_seconds > 0:
stop_event.wait(wait_seconds)
continue
sample_number += 1
capture_started_at = datetime.now(tz=timezone.utc)
tag = f"{tag_prefix}-{session_id[:6]}-{sample_number:04d}"
try:
self._run_capture_job(
{
"duration_seconds": clip_duration_seconds,
"device_index": device_index,
"tag": tag,
"watch_after_capture": auto_watch,
}
)
with self._lock:
self._state["captured_samples"] = sample_number
self._state["last_error"] = None
except Exception as exc:
with self._lock:
self._state["last_error"] = str(exc)
next_capture_at = capture_started_at + timedelta(seconds=interval_seconds)
with self._lock:
if next_capture_at < planned_end_at:
self._state["next_capture_at"] = next_capture_at.isoformat()
else:
self._state["next_capture_at"] = None
finally:
with self._lock:
self._state["running"] = False
self._state["next_capture_at"] = None
self._thread = None

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
SampleLabel = Literal["snore", "not_snore", "unclear"]
BinaryLabel = Literal["snore", "not_snore"]
ReviewState = Literal["none", "pending", "approved", "rejected", "manual"]
LabelSource = Literal["none", "manual", "watch"]
class DeviceInfo(BaseModel):
index: int
name: str
max_input_channels: int
default_samplerate: float
class CaptureRequest(BaseModel):
duration_seconds: int = Field(default=10, ge=1, le=600)
device_index: int | None = None
tag: str | None = Field(default=None, max_length=80)
watch_after_capture: bool = False
class SampleRecord(BaseModel):
sample_id: str
created_at: str
duration_seconds: int
sample_rate: int
device_index: int | None = None
tag: str | None = None
label: SampleLabel | None = None
audio_file: str
spectrogram_file: str
audio_url: str
spectrogram_url: str
proposed_label: BinaryLabel | None = None
proposed_confidence: float | None = None
review_state: ReviewState = "none"
training_approved: bool = False
label_source: LabelSource = "none"
class LabelRequest(BaseModel):
label: SampleLabel
class TrainRequest(BaseModel):
test_size: float = Field(default=0.25, ge=0.1, le=0.5)
class TrainResponse(BaseModel):
model_path: str
trained_samples: int
classes: list[str]
accuracy: float | None = None
class DetectRequest(BaseModel):
sample_id: str
model_path: str | None = None
class DetectResponse(BaseModel):
sample_id: str
predicted_label: str
confidence: float
model_path: str
class OvernightStartRequest(BaseModel):
duration_hours: int = Field(default=8, ge=1, le=12)
clip_duration_seconds: int = Field(default=20, ge=2, le=600)
interval_seconds: int = Field(default=30, ge=2, le=600)
device_index: int | None = None
tag_prefix: str = Field(default="overnight", max_length=80)
auto_watch: bool = True
class OvernightStatusResponse(BaseModel):
running: bool
session_id: str | None = None
started_at: str | None = None
planned_end_at: str | None = None
next_capture_at: str | None = None
duration_hours: int | None = None
clip_duration_seconds: int | None = None
interval_seconds: int | None = None
auto_watch: bool | None = None
captured_samples: int = 0
pending_reviews: int = 0
last_error: str | None = None
class ApiMessage(BaseModel):
message: str

119
src/snorestopper/storage.py Normal file
View File

@@ -0,0 +1,119 @@
from __future__ import annotations
import json
from pathlib import Path
from threading import Lock
from typing import Any
from .config import Settings
def ensure_directories(settings: Settings) -> Path:
for path in (
settings.data_dir,
settings.raw_dir,
settings.spectrogram_dir,
settings.models_dir,
settings.metadata_dir,
):
path.mkdir(parents=True, exist_ok=True)
metadata_file = settings.metadata_dir / "samples.json"
if not metadata_file.exists():
metadata_file.write_text("[]", encoding="utf-8")
return metadata_file
class MetadataStore:
def __init__(self, metadata_file: Path):
self.metadata_file = metadata_file
self._lock = Lock()
self.metadata_file.parent.mkdir(parents=True, exist_ok=True)
if not self.metadata_file.exists():
self.metadata_file.write_text("[]", encoding="utf-8")
def _read(self) -> list[dict[str, Any]]:
if not self.metadata_file.exists():
return []
raw = self.metadata_file.read_text(encoding="utf-8").strip()
if not raw:
return []
try:
payload = json.loads(raw)
except json.JSONDecodeError:
return []
if not isinstance(payload, list):
return []
return [item for item in payload if isinstance(item, dict)]
def _write(self, records: list[dict[str, Any]]) -> None:
self.metadata_file.write_text(
json.dumps(records, indent=2, ensure_ascii=True),
encoding="utf-8",
)
def list_samples(self) -> list[dict[str, Any]]:
with self._lock:
records = self._read()
return sorted(records, key=lambda item: item.get("created_at", ""), reverse=True)
def get_sample(self, sample_id: str) -> dict[str, Any] | None:
with self._lock:
records = self._read()
for record in records:
if record.get("sample_id") == sample_id:
return record
return None
def add_sample(self, record: dict[str, Any]) -> dict[str, Any]:
with self._lock:
records = self._read()
records.append(record)
self._write(records)
return record
def update_label(self, sample_id: str, label: str) -> dict[str, Any]:
with self._lock:
records = self._read()
for record in records:
if record.get("sample_id") == sample_id:
record["label"] = label
self._write(records)
return record
raise KeyError(f"Sample '{sample_id}' was not found")
def update_sample_fields(
self,
sample_id: str,
updates: dict[str, Any],
) -> dict[str, Any]:
with self._lock:
records = self._read()
for record in records:
if record.get("sample_id") == sample_id:
record.update(updates)
self._write(records)
return record
raise KeyError(f"Sample '{sample_id}' was not found")
def list_pending_review_samples(self) -> list[dict[str, Any]]:
with self._lock:
records = self._read()
pending = [item for item in records if item.get("review_state") == "pending"]
return sorted(pending, key=lambda item: item.get("created_at", ""), reverse=True)
def count_pending_review_samples(self) -> int:
with self._lock:
records = self._read()
return sum(1 for item in records if item.get("review_state") == "pending")

View File

@@ -0,0 +1,164 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from .features import extract_feature_vector, read_audio
def _is_training_eligible(sample: dict[str, Any]) -> bool:
label = sample.get("label")
if label not in {"snore", "not_snore"}:
return False
approved_flag = sample.get("training_approved")
if approved_flag is None:
# Backward compatibility for samples created before review gating existed.
return True
return bool(approved_flag)
@dataclass(frozen=True)
class TrainingResult:
model_path: str
trained_samples: int
classes: list[str]
accuracy: float | None
@dataclass(frozen=True)
class PredictionResult:
predicted_label: str
confidence: float
def _build_training_matrix(
samples: list[dict[str, Any]],
raw_dir: Path,
) -> tuple[np.ndarray, list[str]]:
features: list[np.ndarray] = []
labels: list[str] = []
for sample in samples:
if not _is_training_eligible(sample):
continue
label = sample.get("label")
audio_file = sample.get("audio_file")
if not label or not audio_file:
continue
audio_path = raw_dir / str(audio_file)
if not audio_path.exists():
continue
audio, sample_rate = read_audio(audio_path)
if audio.size == 0:
continue
features.append(extract_feature_vector(audio, sample_rate))
labels.append(str(label))
if len(features) < 2:
raise ValueError("At least 2 labeled samples are required for training")
if len(set(labels)) < 2:
raise ValueError("Training needs at least 2 distinct labels")
return np.vstack(features), labels
def train_local_model(
samples: list[dict[str, Any]],
raw_dir: Path,
model_path: Path,
test_size: float = 0.25,
) -> TrainingResult:
features, labels = _build_training_matrix(samples, raw_dir)
encoder = LabelEncoder()
encoded_labels = encoder.fit_transform(labels)
class_counts = np.bincount(encoded_labels)
can_split = (
len(encoded_labels) >= 8
and len(class_counts) >= 2
and bool(np.all(class_counts >= 2))
)
if can_split:
x_train, x_test, y_train, y_test = train_test_split(
features,
encoded_labels,
test_size=test_size,
random_state=42,
stratify=encoded_labels,
)
else:
x_train = features
y_train = encoded_labels
x_test = None
y_test = None
model = RandomForestClassifier(
n_estimators=200,
random_state=42,
class_weight="balanced",
)
model.fit(x_train, y_train)
accuracy: float | None = None
if x_test is not None and y_test is not None:
predicted = model.predict(x_test)
accuracy = float(accuracy_score(y_test, predicted))
model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump({"model": model, "label_encoder": encoder}, model_path)
return TrainingResult(
model_path=str(model_path),
trained_samples=int(features.shape[0]),
classes=[str(item) for item in encoder.classes_.tolist()],
accuracy=accuracy,
)
def predict_sample(
audio_file: str,
raw_dir: Path,
model_path: Path,
) -> PredictionResult:
if not model_path.exists():
raise FileNotFoundError(f"Model file was not found: {model_path}")
model_package = joblib.load(model_path)
model = model_package.get("model")
encoder = model_package.get("label_encoder")
audio_path = raw_dir / audio_file
if not audio_path.exists():
raise FileNotFoundError(f"Audio file was not found: {audio_path}")
audio, sample_rate = read_audio(audio_path)
feature_vector = extract_feature_vector(audio, sample_rate).reshape(1, -1)
if hasattr(model, "predict_proba"):
probabilities = model.predict_proba(feature_vector)[0]
best_probability_index = int(np.argmax(probabilities))
encoded_label = int(model.classes_[best_probability_index])
confidence = float(probabilities[best_probability_index])
else:
encoded_label = int(model.predict(feature_vector)[0])
confidence = 1.0
predicted_label = str(encoder.inverse_transform([encoded_label])[0])
return PredictionResult(predicted_label=predicted_label, confidence=confidence)