feat: add storage and training modules for snore detection
- Implemented `storage.py` for managing metadata storage, including sample addition, retrieval, and review state management.
- Created `training.py` for training a local model using Random Forest, including functions for training and predicting samples.
- Developed a web interface in `app.js` for capturing audio samples, managing labels, and training the model.
- Added HTML structure in `index.html` for the SnoreStopper control room with sections for sample capture, overnight gathering, training, and status display.
- Styled the application with `styles.css` to enhance user experience and interface aesthetics.
This commit is contained in:
4
src/snorestopper/__init__.py
Normal file
4
src/snorestopper/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""SnoreStopper package."""
|
||||
|
||||
# Only the version string is exported from the package root.
__all__ = ["__version__"]

# Package version string.
__version__ = "0.1.0"
|
||||
68
src/snorestopper/audio.py
Normal file
68
src/snorestopper/audio.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except Exception:
|
||||
sd = None
|
||||
|
||||
|
||||
def _require_sounddevice() -> None:
    """Raise ``RuntimeError`` when the optional ``sounddevice`` import failed.

    The module-level ``sd`` name is ``None`` when importing ``sounddevice``
    raised, so the public helpers call this guard before touching it.
    """
    if sd is not None:
        return

    raise RuntimeError(
        "sounddevice is unavailable. Install requirements and ensure PortAudio is available."
    )
|
||||
|
||||
|
||||
def list_input_devices() -> list[dict[str, Any]]:
    """Describe every audio device that exposes at least one input channel.

    Returns:
        One dict per input-capable device with the keys ``index`` (position
        in the ``sounddevice`` device list), ``name``, ``max_input_channels``
        and ``default_samplerate``.

    Raises:
        RuntimeError: If ``sounddevice`` could not be imported.
    """
    _require_sounddevice()

    described: list[dict[str, Any]] = []
    for position, info in enumerate(sd.query_devices()):
        channel_count = int(info.get("max_input_channels", 0))
        if channel_count > 0:
            # Output-only devices (zero input channels) are left out.
            described.append(
                {
                    "index": position,
                    "name": str(info.get("name", f"Input Device {position}")),
                    "max_input_channels": channel_count,
                    "default_samplerate": float(info.get("default_samplerate", 0.0)),
                }
            )

    return described
|
||||
|
||||
|
||||
def capture_audio(
    duration_seconds: int,
    sample_rate: int,
    channels: int,
    device_index: int | None = None,
) -> np.ndarray:
    """Record a blocking clip from an input device and return it as mono float32.

    Args:
        duration_seconds: Clip length in seconds; must be positive.
        sample_rate: Sampling rate passed to ``sounddevice``.
        channels: Number of channels to record.
        device_index: Device to record from, or ``None`` to pass no explicit
            device to ``sounddevice``.

    Returns:
        A float32 array; two-dimensional recordings are averaged over axis 1.

    Raises:
        RuntimeError: If ``sounddevice`` is unavailable or the duration is
            not positive.
    """
    _require_sounddevice()

    if not duration_seconds > 0:
        raise RuntimeError("duration_seconds must be positive")

    total_frames = int(duration_seconds * sample_rate)
    buffer = sd.rec(
        total_frames,
        samplerate=sample_rate,
        channels=channels,
        dtype="float32",
        device=device_index,
    )
    # sd.rec returns immediately; block here until the buffer has been filled.
    sd.wait()

    samples = np.asarray(buffer, dtype=np.float32)
    # Collapse to mono for simpler feature extraction and training.
    return np.mean(samples, axis=1) if samples.ndim == 2 else samples
|
||||
43
src/snorestopper/config.py
Normal file
43
src/snorestopper/config.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Settings:
    """Immutable runtime configuration for the SnoreStopper service.

    ``frozen=True`` makes every field read-only after construction.
    Instances are produced by ``load_settings``.
    """

    # Root directory that the data and web directories are derived from.
    project_root: Path
    # Parent directory of the raw/spectrogram/model/metadata folders.
    data_dir: Path
    # Directory for captured audio files.
    raw_dir: Path
    # Directory for rendered spectrogram images.
    spectrogram_dir: Path
    # Directory for trained model files.
    models_dir: Path
    # Directory for sample metadata.
    metadata_dir: Path
    # Directory holding the static web UI.
    web_dir: Path
    # Recording sample rate (Hz).
    sample_rate: int
    # Number of channels to record.
    channels: int
    # Shortest clip length accepted, in seconds.
    min_duration_seconds: int
    # Longest clip length accepted, in seconds.
    max_duration_seconds: int
    # File name of the trained model inside ``models_dir``.
    model_file_name: str
|
||||
|
||||
|
||||
def load_settings() -> Settings:
    """Build the service ``Settings`` from ``SNORESTOPPER_*`` environment variables.

    Every variable is optional. The project root defaults to two directory
    levels above this module's directory; the numeric values default to
    16000 Hz, 1 channel, and a 2-90 second clip range.

    Raises:
        ValueError: If a numeric environment variable is not a valid integer.
    """
    env = os.getenv

    root = Path(env("SNORESTOPPER_ROOT", Path(__file__).resolve().parents[2]))
    data_root = root / "data"

    rate = int(env("SNORESTOPPER_SAMPLE_RATE", "16000"))
    channel_count = int(env("SNORESTOPPER_CHANNELS", "1"))
    shortest = int(env("SNORESTOPPER_MIN_DURATION", "2"))
    longest = int(env("SNORESTOPPER_MAX_DURATION", "90"))
    model_name = env("SNORESTOPPER_MODEL_FILE", "snore_classifier.joblib")

    return Settings(
        project_root=root,
        data_dir=data_root,
        raw_dir=data_root / "raw",
        spectrogram_dir=data_root / "spectrograms",
        models_dir=data_root / "models",
        metadata_dir=data_root / "meta",
        web_dir=root / "web",
        sample_rate=rate,
        channels=channel_count,
        min_duration_seconds=shortest,
        max_duration_seconds=longest,
        model_file_name=model_name,
    )
|
||||
95
src/snorestopper/features.py
Normal file
95
src/snorestopper/features.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from scipy import signal
|
||||
|
||||
|
||||
def save_audio_as_flac(audio: np.ndarray, sample_rate: int, output_path: Path) -> None:
    """Write ``audio`` to ``output_path`` as a FLAC file, creating parent directories."""
    target_dir = output_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    sf.write(
        file=str(output_path),
        data=audio,
        samplerate=sample_rate,
        format="FLAC",
    )
|
||||
|
||||
|
||||
def read_audio(audio_path: Path) -> tuple[np.ndarray, int]:
    """Load an audio file as a mono float32 waveform.

    Returns:
        ``(waveform, sample_rate)``; two-dimensional data is averaged over
        axis 1.
    """
    samples, rate = sf.read(str(audio_path), always_2d=False)
    mono = np.asarray(samples, dtype=np.float32)
    if mono.ndim == 2:
        mono = np.mean(mono, axis=1)
    return mono, int(rate)
|
||||
|
||||
|
||||
def _compute_spectrogram(audio: np.ndarray, sample_rate: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
n_samples = int(audio.shape[0])
|
||||
if n_samples <= 1:
|
||||
nperseg = 1
|
||||
noverlap = 0
|
||||
else:
|
||||
nperseg = min(1024, n_samples)
|
||||
noverlap = min(int(nperseg * 0.5), nperseg - 1)
|
||||
|
||||
return signal.spectrogram(
|
||||
audio,
|
||||
fs=sample_rate,
|
||||
nperseg=nperseg,
|
||||
noverlap=noverlap,
|
||||
scaling="spectrum",
|
||||
mode="magnitude",
|
||||
)
|
||||
|
||||
|
||||
def save_spectrogram(audio: np.ndarray, sample_rate: int, output_path: Path) -> None:
    """Render a dB-scaled spectrogram of ``audio`` and save it as an image.

    Args:
        audio: One-dimensional waveform.
        sample_rate: Sampling rate of ``audio`` in Hz.
        output_path: Image file to write; parent directories are created.
    """
    frequencies, times, magnitude = _compute_spectrogram(audio, sample_rate)
    # The spectrogram holds magnitudes (mode="magnitude"), so the power level
    # in dB is 20*log10(|X|); the small offset avoids log10(0) on silence.
    power_db = 20.0 * np.log10(magnitude + 1e-12)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig, axis = plt.subplots(figsize=(8, 3))
    try:
        mesh = axis.pcolormesh(times, frequencies, power_db, shading="gouraud", cmap="magma")
        axis.set_title("SnoreStopper Spectrogram")
        axis.set_ylabel("Frequency (Hz)")
        axis.set_xlabel("Time (s)")
        fig.colorbar(mesh, ax=axis, label="Power (dB)")
        fig.tight_layout()
        fig.savefig(output_path, dpi=120)
    finally:
        # pyplot keeps every open figure alive; close it even when plotting
        # or saving fails so a long-running process does not accumulate them.
        plt.close(fig)
|
||||
|
||||
|
||||
def _band_energy(power: np.ndarray, frequencies: np.ndarray, low: float, high: float) -> float:
|
||||
mask = (frequencies >= low) & (frequencies < high)
|
||||
if not np.any(mask):
|
||||
return 0.0
|
||||
return float(np.mean(power[mask, :]))
|
||||
|
||||
|
||||
def extract_feature_vector(audio: np.ndarray, sample_rate: int) -> np.ndarray:
    """Summarise a waveform as a 7-element float32 feature vector.

    The features are computed on ``log1p`` of the magnitude spectrogram, in
    this order: mean, standard deviation, maximum, mean level in the
    20-250 Hz, 250-1000 Hz and 1000-4000 Hz bands, and the spectral centroid.
    """
    frequencies, _, spectrum = _compute_spectrogram(audio, sample_rate)
    log_spectrum = np.log1p(spectrum)

    overall = [
        float(np.mean(log_spectrum)),
        float(np.std(log_spectrum)),
        float(np.max(log_spectrum)),
    ]

    band_edges = ((20.0, 250.0), (250.0, 1000.0), (1000.0, 4000.0))
    band_levels = [
        _band_energy(log_spectrum, frequencies, lower, upper)
        for lower, upper in band_edges
    ]

    # The tiny offset keeps the weights from summing to zero on silent input.
    per_bin_weight = np.mean(log_spectrum, axis=1) + 1e-9
    centroid = float(np.average(frequencies, weights=per_bin_weight))

    return np.asarray([*overall, *band_levels, centroid], dtype=np.float32)
|
||||
555
src/snorestopper/main.py
Normal file
555
src/snorestopper/main.py
Normal file
@@ -0,0 +1,555 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from .audio import capture_audio, list_input_devices
|
||||
from .config import Settings, load_settings
|
||||
from .features import save_audio_as_flac, save_spectrogram
|
||||
from .overnight import OvernightCaptureController
|
||||
from .schemas import (
|
||||
ApiMessage,
|
||||
CaptureRequest,
|
||||
DetectRequest,
|
||||
DetectResponse,
|
||||
DeviceInfo,
|
||||
LabelRequest,
|
||||
OvernightStartRequest,
|
||||
OvernightStatusResponse,
|
||||
SampleRecord,
|
||||
TrainRequest,
|
||||
TrainResponse,
|
||||
)
|
||||
from .storage import MetadataStore, ensure_directories
|
||||
from .training import predict_sample, train_local_model
|
||||
|
||||
# Module-level application state, built once at import time.
settings = load_settings()
metadata_file = ensure_directories(settings)
store = MetadataStore(metadata_file)

# Labels a model can be trained on / can propose.
BINARY_LABELS = {"snore", "not_snore"}
# Labels a human may assign to a sample ("unclear" is never approved for training).
VALID_SAMPLE_LABELS = {"snore", "not_snore", "unclear"}
# Allowed values for a record's review_state / label_source fields.
REVIEW_STATES = {"none", "pending", "approved", "rejected", "manual"}
LABEL_SOURCES = {"none", "manual", "watch"}

app = FastAPI(
    title="SnoreStopper",
    version="0.1.0",
    description="Self-hosted snore tracking, labeling, and model training",
)

# A wildcard origin must not be combined with allow_credentials=True: it lets
# any web page send credentialed cross-site requests. No handler in this
# module reads cookies or auth headers, so credentials are disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/web", StaticFiles(directory=settings.web_dir), name="web")
# NOTE(review): this serves everything under data_dir, which per load_settings
# also contains the "meta" and "models" subdirectories -- confirm that is intended.
app.mount("/data", StaticFiles(directory=settings.data_dir), name="data")
|
||||
|
||||
|
||||
def _to_int(value: Any, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _to_optional_float(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _invert_binary_label(label: str) -> str:
|
||||
if label == "snore":
|
||||
return "not_snore"
|
||||
if label == "not_snore":
|
||||
return "snore"
|
||||
raise ValueError(f"Unsupported binary label: {label}")
|
||||
|
||||
|
||||
def _to_sample_record(record: dict[str, Any]) -> SampleRecord:
    """Convert a stored metadata dict into the ``SampleRecord`` API model.

    Missing keys fall back to neutral defaults, and values outside the
    allowed label/state sets are replaced (``None`` for labels, ``"none"``
    for review state and label source).
    """
    audio_name = str(record.get("audio_file", ""))
    image_name = str(record.get("spectrogram_file", ""))

    stored_label = record.get("label")
    stored_proposal = record.get("proposed_label")
    stored_state = str(record.get("review_state", "none"))
    stored_source = str(record.get("label_source", "none"))

    return SampleRecord(
        sample_id=str(record.get("sample_id", "")),
        created_at=str(record.get("created_at", "")),
        duration_seconds=_to_int(record.get("duration_seconds"), default=0),
        sample_rate=_to_int(record.get("sample_rate"), default=settings.sample_rate),
        device_index=record.get("device_index"),
        tag=record.get("tag"),
        label=stored_label if stored_label in VALID_SAMPLE_LABELS else None,
        audio_file=audio_name,
        spectrogram_file=image_name,
        # URLs point into the "/data" static mount.
        audio_url=f"/data/raw/{audio_name}",
        spectrogram_url=f"/data/spectrograms/{image_name}",
        proposed_label=stored_proposal if stored_proposal in BINARY_LABELS else None,
        proposed_confidence=_to_optional_float(record.get("proposed_confidence")),
        review_state=stored_state if stored_state in REVIEW_STATES else "none",
        training_approved=bool(record.get("training_approved", False)),
        label_source=stored_source if stored_source in LABEL_SOURCES else "none",
    )
|
||||
|
||||
|
||||
def _resolve_model_path(model_path_value: str | None) -> Path:
    """Resolve an optional model path to a concrete location.

    An empty or missing value selects the default model file inside
    ``settings.models_dir``. Relative paths are resolved against
    ``settings.project_root``; absolute paths (after ``~`` expansion) are
    returned unchanged.
    """
    # NOTE(review): detect_sample passes a request-supplied value here, so a
    # client can point at any file on disk -- confirm whether this should be
    # restricted to the models directory.
    if model_path_value:
        requested = Path(model_path_value).expanduser()
        return requested if requested.is_absolute() else settings.project_root / requested

    return settings.models_dir / settings.model_file_name
|
||||
|
||||
|
||||
def _validate_capture_duration_seconds(duration_seconds: int, app_settings: Settings) -> None:
    """Reject clip durations outside the configured minimum/maximum.

    Raises:
        HTTPException: Status 400 when ``duration_seconds`` is below
            ``app_settings.min_duration_seconds`` or above
            ``app_settings.max_duration_seconds``.
    """
    shortest = app_settings.min_duration_seconds
    longest = app_settings.max_duration_seconds

    problem: str | None = None
    if duration_seconds < shortest:
        problem = f"duration_seconds must be at least {shortest} seconds"
    elif duration_seconds > longest:
        problem = f"duration_seconds must be {longest} seconds or less"

    if problem is not None:
        raise HTTPException(status_code=400, detail=problem)
|
||||
|
||||
|
||||
def _capture_and_store_sample(
    *,
    duration_seconds: int,
    device_index: int | None,
    tag: str | None,
    watch_after_capture: bool,
) -> dict[str, Any]:
    """Record one clip, persist audio + spectrogram, and register its metadata.

    Args:
        duration_seconds: Clip length; validated against the configured range.
        device_index: Input device to record from, or ``None``.
        tag: Optional free-form tag stored with the sample.
        watch_after_capture: When true and a trained model file exists, run
            it on the new clip and store its prediction as a pending proposal.

    Returns:
        The metadata dict that was added to the store.

    Raises:
        HTTPException: 400 for an out-of-range duration, 500 when recording
            or saving fails.
    """
    # Local import keeps this change self-contained; hoist to the module
    # imports if preferred.
    import logging

    _validate_capture_duration_seconds(duration_seconds, settings)

    try:
        waveform = capture_audio(
            duration_seconds=duration_seconds,
            sample_rate=settings.sample_rate,
            channels=settings.channels,
            device_index=device_index,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Recording failed: {exc}") from exc

    sample_id = uuid4().hex
    audio_file = f"{sample_id}.flac"
    spectrogram_file = f"{sample_id}.png"

    audio_path = settings.raw_dir / audio_file
    spectrogram_path = settings.spectrogram_dir / spectrogram_file

    try:
        save_audio_as_flac(waveform, settings.sample_rate, audio_path)
        save_spectrogram(waveform, settings.sample_rate, spectrogram_path)
    except Exception as exc:
        # No metadata record will reference these files, so remove whatever
        # was written before the failure instead of leaving orphans on disk.
        for leftover in (audio_path, spectrogram_path):
            try:
                leftover.unlink()
            except OSError:
                # Best-effort cleanup: the file may never have been created.
                pass
        raise HTTPException(status_code=500, detail=f"Saving sample failed: {exc}") from exc

    record: dict[str, Any] = {
        "sample_id": sample_id,
        "created_at": datetime.now(tz=timezone.utc).isoformat(),
        "duration_seconds": duration_seconds,
        "sample_rate": settings.sample_rate,
        "device_index": device_index,
        "tag": tag,
        "label": None,
        "audio_file": audio_file,
        "spectrogram_file": spectrogram_file,
        "proposed_label": None,
        "proposed_confidence": None,
        "review_state": "none",
        "training_approved": False,
        "label_source": "none",
    }

    if watch_after_capture:
        model_path = settings.models_dir / settings.model_file_name
        if model_path.exists():
            try:
                prediction = predict_sample(
                    audio_file=audio_file,
                    raw_dir=settings.raw_dir,
                    model_path=model_path,
                )
                if prediction.predicted_label in BINARY_LABELS:
                    record.update(
                        {
                            "proposed_label": prediction.predicted_label,
                            "proposed_confidence": prediction.confidence,
                            "review_state": "pending",
                            "training_approved": False,
                            "label_source": "watch",
                        }
                    )
            except Exception:
                # Capture should still be retained even when watch inference
                # fails, but record the failure instead of hiding it.
                logging.getLogger(__name__).exception(
                    "Watch inference failed for sample %s; storing it without a proposal",
                    sample_id,
                )

    store.add_sample(record)
    return record
|
||||
|
||||
|
||||
def _background_capture_job(payload: dict[str, Any]) -> dict[str, Any]:
    """Run one capture from a loosely-typed payload dict.

    Reads ``duration_seconds`` (default 10), ``device_index``, ``tag`` and
    ``watch_after_capture`` from ``payload`` and delegates to
    ``_capture_and_store_sample``.
    """
    raw_tag = payload.get("tag")
    clip_seconds = _to_int(payload.get("duration_seconds"), default=10)
    run_watch = bool(payload.get("watch_after_capture", False))

    return _capture_and_store_sample(
        duration_seconds=clip_seconds,
        device_index=payload.get("device_index"),
        tag=None if raw_tag is None else str(raw_tag),
        watch_after_capture=run_watch,
    )
|
||||
|
||||
|
||||
# Single module-level controller for the overnight capture worker. It runs
# captures through _background_capture_job and reads the pending-review
# count from the metadata store.
overnight_controller = OvernightCaptureController(
    run_capture_job=_background_capture_job,
    pending_review_count=store.count_pending_review_samples,
)
|
||||
|
||||
|
||||
# Redirect the bare root URL to the web UI entry page.
@app.get("/", include_in_schema=False)
def root_redirect() -> RedirectResponse:
    landing_page = "/web/index.html"
    return RedirectResponse(url=landing_page)
|
||||
|
||||
|
||||
# Liveness endpoint: reports a fixed "ok" status plus the project root,
# sample rate and channel count from the loaded settings.
@app.get("/api/health")
def health() -> dict[str, Any]:
    report: dict[str, Any] = {"status": "ok"}
    report["project_root"] = str(settings.project_root)
    report["sample_rate"] = settings.sample_rate
    report["channels"] = settings.channels
    return report
|
||||
|
||||
|
||||
# List the input-capable audio devices; any failure while querying them is
# reported as HTTP 500.
@app.get("/api/audio/devices", response_model=list[DeviceInfo])
def get_audio_devices() -> list[DeviceInfo]:
    try:
        found = list_input_devices()
    except Exception as exc:
        message = f"Unable to query input devices: {exc}"
        raise HTTPException(status_code=500, detail=message) from exc

    result: list[DeviceInfo] = []
    for entry in found:
        result.append(DeviceInfo(**entry))
    return result
|
||||
|
||||
|
||||
# Return every stored sample as an API record.
@app.get("/api/samples", response_model=list[SampleRecord])
def list_samples() -> list[SampleRecord]:
    stored = store.list_samples()
    return list(map(_to_sample_record, stored))
|
||||
|
||||
|
||||
# Record a single clip as described by the request and return its record.
@app.post("/api/samples/capture", response_model=SampleRecord)
def capture_sample(payload: CaptureRequest) -> SampleRecord:
    capture_options = {
        "duration_seconds": payload.duration_seconds,
        "device_index": payload.device_index,
        "tag": payload.tag,
        "watch_after_capture": payload.watch_after_capture,
    }
    stored = _capture_and_store_sample(**capture_options)
    return _to_sample_record(stored)
|
||||
|
||||
|
||||
# Apply a manual label to a sample. Only binary labels ("snore"/"not_snore")
# mark the sample as approved for training; any watch proposal is cleared.
# Responds 404 when the store raises KeyError for the sample id.
@app.post("/api/samples/{sample_id}/label", response_model=SampleRecord)
def label_sample(sample_id: str, payload: LabelRequest) -> SampleRecord:
    usable_for_training = payload.label in BINARY_LABELS
    manual_fields = {
        "label": payload.label,
        "review_state": "manual",
        "training_approved": usable_for_training,
        "label_source": "manual",
        "proposed_label": None,
        "proposed_confidence": None,
    }

    try:
        refreshed = store.update_sample_fields(sample_id=sample_id, updates=manual_fields)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    return _to_sample_record(refreshed)
|
||||
|
||||
|
||||
# Run the trained model on an existing sample and store its prediction as a
# pending proposal. Responds 404 for an unknown sample, a missing model file
# or a missing audio file, 409 when the sample is already approved for
# training, and 500 when prediction fails or yields a non-binary class.
@app.post("/api/watch/{sample_id}/propose", response_model=SampleRecord)
def propose_watch_label(sample_id: str) -> SampleRecord:
    existing = store.get_sample(sample_id)
    if existing is None:
        raise HTTPException(status_code=404, detail=f"Sample '{sample_id}' was not found")

    already_approved = bool(existing.get("training_approved", False))
    if already_approved:
        raise HTTPException(
            status_code=409,
            detail="Sample is already approved for training. Clear or relabel it before watch proposal.",
        )

    classifier_file = settings.models_dir / settings.model_file_name
    if not classifier_file.exists():
        raise HTTPException(
            status_code=404,
            detail="No trained model found. Train a model first, then run snore watch proposals.",
        )

    try:
        outcome = predict_sample(
            audio_file=str(existing.get("audio_file", "")),
            raw_dir=settings.raw_dir,
            model_path=classifier_file,
        )
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Watch proposal failed: {exc}") from exc

    if outcome.predicted_label not in BINARY_LABELS:
        raise HTTPException(
            status_code=500,
            detail="Model predicted an unsupported class. Retrain with labels 'snore' and 'not_snore'.",
        )

    try:
        # The final label is cleared: the proposal stays pending until reviewed.
        refreshed = store.update_sample_fields(
            sample_id=sample_id,
            updates={
                "label": None,
                "proposed_label": outcome.predicted_label,
                "proposed_confidence": outcome.confidence,
                "review_state": "pending",
                "training_approved": False,
                "label_source": "watch",
            },
        )
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    return _to_sample_record(refreshed)
|
||||
|
||||
|
||||
# Return the samples the store reports as pending review.
@app.get("/api/watch/pending", response_model=list[SampleRecord])
def list_pending_watch_reviews() -> list[SampleRecord]:
    awaiting_review = store.list_pending_review_samples()
    return list(map(_to_sample_record, awaiting_review))
|
||||
|
||||
|
||||
# Thumbs-up: accept the model's pending proposal as the sample's label and
# approve the sample for training. Responds 404 for an unknown sample and
# 400 when the sample is not pending review or has no binary proposal.
@app.post("/api/watch/{sample_id}/thumbs-up", response_model=SampleRecord)
def approve_watch_prediction(sample_id: str) -> SampleRecord:
    sample = store.get_sample(sample_id)
    if sample is None:
        raise HTTPException(status_code=404, detail=f"Sample '{sample_id}' was not found")

    # Same guard as the thumbs-down route. Without it, a thumbs-up sent after
    # a thumbs-down would overwrite the corrected label with the rejected
    # proposal, because proposed_label stays on the record after review.
    if sample.get("review_state") != "pending":
        raise HTTPException(
            status_code=400,
            detail="Sample is not currently pending watch review",
        )

    proposed_label = sample.get("proposed_label")
    if proposed_label not in BINARY_LABELS:
        raise HTTPException(
            status_code=400,
            detail="Sample has no pending watch proposal to approve",
        )

    try:
        updated = store.update_sample_fields(
            sample_id=sample_id,
            updates={
                "label": proposed_label,
                "review_state": "approved",
                "training_approved": True,
                "label_source": "watch",
            },
        )
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    return _to_sample_record(updated)
|
||||
|
||||
|
||||
# Thumbs-down: the reviewer disagrees with the proposal, so the opposite
# binary label is stored and the sample is approved for training.
# NOTE(review): the record ends up with review_state "approved" even though
# "rejected" is a valid state -- confirm this is what the UI expects.
@app.post("/api/watch/{sample_id}/thumbs-down", response_model=SampleRecord)
def reject_watch_prediction(sample_id: str) -> SampleRecord:
    existing = store.get_sample(sample_id)
    if existing is None:
        raise HTTPException(status_code=404, detail=f"Sample '{sample_id}' was not found")

    awaiting_review = existing.get("review_state") == "pending"
    if not awaiting_review:
        raise HTTPException(
            status_code=400,
            detail="Sample is not currently pending watch review",
        )

    model_guess = existing.get("proposed_label")
    if model_guess not in BINARY_LABELS:
        raise HTTPException(
            status_code=400,
            detail="Sample has no pending watch proposal to invert",
        )

    try:
        opposite = _invert_binary_label(str(model_guess))
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    corrected_fields = {
        "label": opposite,
        "review_state": "approved",
        "training_approved": True,
        "label_source": "watch",
    }
    try:
        refreshed = store.update_sample_fields(sample_id=sample_id, updates=corrected_fields)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    return _to_sample_record(refreshed)
|
||||
|
||||
|
||||
# Train the local model on the stored samples and write it to the default
# model path. A ValueError from training becomes HTTP 400; any other failure
# becomes HTTP 500.
@app.post("/api/train", response_model=TrainResponse)
def train_model(payload: TrainRequest) -> TrainResponse:
    destination = settings.models_dir / settings.model_file_name
    all_samples = store.list_samples()

    try:
        outcome = train_local_model(
            samples=all_samples,
            raw_dir=settings.raw_dir,
            model_path=destination,
            test_size=payload.test_size,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Training failed: {exc}") from exc

    summary = {
        "model_path": outcome.model_path,
        "trained_samples": outcome.trained_samples,
        "classes": outcome.classes,
        "accuracy": outcome.accuracy,
    }
    return TrainResponse(**summary)
|
||||
|
||||
|
||||
# Report the current state of the overnight capture controller.
@app.get("/api/overnight/status", response_model=OvernightStatusResponse)
def overnight_status() -> OvernightStatusResponse:
    snapshot = overnight_controller.get_status()
    return OvernightStatusResponse(**snapshot)
|
||||
|
||||
|
||||
# Start an overnight capture session. Responds 400 for an out-of-range clip
# duration, an interval shorter than the clip, or auto_watch without a
# trained model file; responds 409 when the controller raises RuntimeError.
@app.post("/api/overnight/start", response_model=OvernightStatusResponse)
def start_overnight_capture(payload: OvernightStartRequest) -> OvernightStatusResponse:
    _validate_capture_duration_seconds(payload.clip_duration_seconds, settings)

    interval_too_short = payload.interval_seconds < payload.clip_duration_seconds
    if interval_too_short:
        raise HTTPException(
            status_code=400,
            detail="interval_seconds must be greater than or equal to clip_duration_seconds",
        )

    if payload.auto_watch:
        classifier_file = settings.models_dir / settings.model_file_name
        if not classifier_file.exists():
            raise HTTPException(
                status_code=400,
                detail=(
                    "auto_watch is enabled but no trained model was found. Train a model "
                    "first or disable auto_watch."
                ),
            )

    session_options = {
        "duration_hours": payload.duration_hours,
        "clip_duration_seconds": payload.clip_duration_seconds,
        "interval_seconds": payload.interval_seconds,
        "device_index": payload.device_index,
        "tag_prefix": payload.tag_prefix,
        "auto_watch": payload.auto_watch,
    }
    try:
        snapshot = overnight_controller.start_session(**session_options)
    except RuntimeError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc

    return OvernightStatusResponse(**snapshot)
|
||||
|
||||
|
||||
# Ask the overnight controller to stop and return its resulting status.
@app.post("/api/overnight/stop", response_model=OvernightStatusResponse)
def stop_overnight_capture() -> OvernightStatusResponse:
    snapshot = overnight_controller.stop_session()
    return OvernightStatusResponse(**snapshot)
|
||||
|
||||
|
||||
# Run a model (default or request-supplied path) on a stored sample and
# return its prediction without modifying the sample. Responds 404 for an
# unknown sample or a FileNotFoundError from prediction, 500 for any other
# failure.
@app.post("/api/detect", response_model=DetectResponse)
def detect_sample(payload: DetectRequest) -> DetectResponse:
    existing = store.get_sample(payload.sample_id)
    if existing is None:
        raise HTTPException(status_code=404, detail=f"Sample '{payload.sample_id}' was not found")

    classifier_file = _resolve_model_path(payload.model_path)

    try:
        outcome = predict_sample(
            audio_file=str(existing["audio_file"]),
            raw_dir=settings.raw_dir,
            model_path=classifier_file,
        )
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Detection failed: {exc}") from exc

    return DetectResponse(
        sample_id=payload.sample_id,
        predicted_label=outcome.predicted_label,
        confidence=outcome.confidence,
        model_path=str(classifier_file),
    )
|
||||
|
||||
|
||||
# Keep recordings but reset labels to quickly restart training experiments.
# Records without a sample_id are skipped.
@app.post("/api/reset", response_model=ApiMessage)
def reset_sample_labels() -> ApiMessage:
    cleared_fields = {
        "label": None,
        "proposed_label": None,
        "proposed_confidence": None,
        "review_state": "none",
        "training_approved": False,
        "label_source": "none",
    }

    for entry in store.list_samples():
        if not entry.get("sample_id"):
            continue
        # Pass a fresh copy each time so the store never shares one dict
        # between records.
        store.update_sample_fields(
            sample_id=str(entry["sample_id"]),
            updates=dict(cleared_fields),
        )

    return ApiMessage(message="All sample labels and watch proposals were reset")
|
||||
167
src/snorestopper/overnight.py
Normal file
167
src/snorestopper/overnight.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any, Callable
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class OvernightCaptureController:
    """Run periodic capture jobs on a background thread for one session at a time.

    The controller owns a status dict (guarded by a lock) that the API
    exposes, a stop event, and the worker thread. Capturing itself is
    delegated to the ``run_capture_job`` callable.
    """

    def __init__(
        self,
        run_capture_job: Callable[[dict[str, Any]], dict[str, Any]],
        pending_review_count: Callable[[], int],
    ):
        """Store the callbacks and initialise an idle state.

        Args:
            run_capture_job: Called once per capture with a payload dict
                (``duration_seconds``, ``device_index``, ``tag``,
                ``watch_after_capture``); an exception marks that attempt
                as failed.
            pending_review_count: Called on every status read to fill the
                ``pending_reviews`` field.
        """
        self._run_capture_job = run_capture_job
        self._pending_review_count = pending_review_count
        self._lock = Lock()
        self._stop_event = Event()
        self._thread: Thread | None = None
        self._state: dict[str, Any] = {
            "running": False,
            "session_id": None,
            "started_at": None,
            "planned_end_at": None,
            "next_capture_at": None,
            "duration_hours": None,
            "clip_duration_seconds": None,
            "interval_seconds": None,
            "auto_watch": None,
            "captured_samples": 0,
            "last_error": None,
        }

    def start_session(
        self,
        *,
        duration_hours: int,
        clip_duration_seconds: int,
        interval_seconds: int,
        device_index: int | None,
        tag_prefix: str,
        auto_watch: bool,
    ) -> dict[str, Any]:
        """Start a new capture session on a daemon thread.

        Returns:
            The status snapshot taken right after the worker was started.

        Raises:
            RuntimeError: If a session is already running.
        """
        with self._lock:
            if self._state["running"]:
                raise RuntimeError("Overnight capture is already running")

            session_id = uuid4().hex
            started_at = datetime.now(tz=timezone.utc)
            planned_end_at = started_at + timedelta(hours=duration_hours)

            # A fresh event per session: a previous worker still holds (and
            # may still observe) the old one.
            self._stop_event = Event()
            self._state = {
                "running": True,
                "session_id": session_id,
                "started_at": started_at.isoformat(),
                "planned_end_at": planned_end_at.isoformat(),
                "next_capture_at": started_at.isoformat(),
                "duration_hours": duration_hours,
                "clip_duration_seconds": clip_duration_seconds,
                "interval_seconds": interval_seconds,
                "auto_watch": auto_watch,
                "captured_samples": 0,
                "last_error": None,
            }

            worker = Thread(
                target=self._run_loop,
                kwargs={
                    "session_id": session_id,
                    "planned_end_at": planned_end_at,
                    "clip_duration_seconds": clip_duration_seconds,
                    "interval_seconds": interval_seconds,
                    "device_index": device_index,
                    "tag_prefix": tag_prefix,
                    "auto_watch": auto_watch,
                    "stop_event": self._stop_event,
                },
                daemon=True,
                name="snorestopper-overnight-capture",
            )
            self._thread = worker
            worker.start()

            return self._status_locked()

    def stop_session(self) -> dict[str, Any]:
        """Signal the worker to stop and return the current status.

        The worker is joined for at most two seconds, so the returned status
        may still report ``running`` while a capture is in progress.
        """
        thread_to_join: Thread | None = None
        with self._lock:
            if not self._state["running"]:
                return self._status_locked()

            self._stop_event.set()
            thread_to_join = self._thread

        # Join outside the lock: the worker needs the lock to publish its
        # final state.
        if thread_to_join is not None:
            thread_to_join.join(timeout=2.0)

        with self._lock:
            return self._status_locked()

    def get_status(self) -> dict[str, Any]:
        """Return a snapshot of the session state plus ``pending_reviews``."""
        with self._lock:
            return self._status_locked()

    def _status_locked(self) -> dict[str, Any]:
        """Build the status snapshot; the caller must hold ``self._lock``."""
        status = dict(self._state)
        status["pending_reviews"] = int(self._pending_review_count())
        return status

    def _run_loop(
        self,
        *,
        session_id: str,
        planned_end_at: datetime,
        clip_duration_seconds: int,
        interval_seconds: int,
        device_index: int | None,
        tag_prefix: str,
        auto_watch: bool,
        stop_event: Event,
    ) -> None:
        """Worker body: capture every ``interval_seconds`` until stopped or expired.

        Each attempt is tagged ``<tag_prefix>-<first 6 chars of session id>-<attempt>``.
        A failing capture is recorded in ``last_error`` and the loop carries on.
        """
        attempt_number = 0
        # Counted separately from attempts so failed captures are not
        # reported as captured samples.
        captured_count = 0
        next_capture_at = datetime.now(tz=timezone.utc)

        try:
            while not stop_event.is_set():
                now = datetime.now(tz=timezone.utc)
                if now >= planned_end_at:
                    break

                wait_seconds = (next_capture_at - now).total_seconds()
                if wait_seconds > 0:
                    # Never sleep past the planned end; otherwise the session
                    # stays "running" for up to one interval after it expired.
                    remaining_seconds = (planned_end_at - now).total_seconds()
                    stop_event.wait(min(wait_seconds, remaining_seconds))
                    continue

                attempt_number += 1
                capture_started_at = datetime.now(tz=timezone.utc)
                tag = f"{tag_prefix}-{session_id[:6]}-{attempt_number:04d}"

                try:
                    self._run_capture_job(
                        {
                            "duration_seconds": clip_duration_seconds,
                            "device_index": device_index,
                            "tag": tag,
                            "watch_after_capture": auto_watch,
                        }
                    )
                except Exception as exc:
                    with self._lock:
                        self._state["last_error"] = str(exc)
                else:
                    captured_count += 1
                    with self._lock:
                        self._state["captured_samples"] = captured_count
                        self._state["last_error"] = None

                # Schedule relative to the capture start so the interval
                # includes the clip duration.
                next_capture_at = capture_started_at + timedelta(seconds=interval_seconds)
                with self._lock:
                    if next_capture_at < planned_end_at:
                        self._state["next_capture_at"] = next_capture_at.isoformat()
                    else:
                        self._state["next_capture_at"] = None
        finally:
            with self._lock:
                self._state["running"] = False
                self._state["next_capture_at"] = None
                self._thread = None
|
||||
98
src/snorestopper/schemas.py
Normal file
98
src/snorestopper/schemas.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Closed sets of string values shared by the request/response models below.

# Label a sample can carry; "unclear" is allowed here but not in BinaryLabel.
SampleLabel = Literal["snore", "not_snore", "unclear"]

# The two-class subset of SampleLabel.
BinaryLabel = Literal["snore", "not_snore"]

# Review status values a sample record can hold.
ReviewState = Literal["none", "pending", "approved", "rejected", "manual"]

# Origin of a sample's label.
LabelSource = Literal["none", "manual", "watch"]
|
||||
|
||||
|
||||
# Description of one audio input device. The field names match the keys of
# the dicts built by ``list_input_devices`` in the audio module.
class DeviceInfo(BaseModel):
    # Position of the device in the enumerated device list.
    index: int
    # Human-readable device name.
    name: str
    # Number of input channels the device reports.
    max_input_channels: int
    # Default sample rate the device reports.
    default_samplerate: float
|
||||
|
||||
|
||||
# Parameters for a single capture.
class CaptureRequest(BaseModel):
    # Clip length; validated to 1..600, defaults to 10.
    duration_seconds: int = Field(default=10, ge=1, le=600)
    # Input device to use; None leaves the choice to the capture code.
    device_index: int | None = None
    # Optional free-form tag, at most 80 characters.
    tag: str | None = Field(default=None, max_length=80)
    # Opt-in flag; off by default.
    # NOTE(review): presumably triggers model "watch" labelling after capture -- confirm.
    watch_after_capture: bool = False
|
||||
|
||||
|
||||
# One stored sample together with its labelling and review metadata.
class SampleRecord(BaseModel):
    # --- identity and capture parameters ---
    sample_id: str
    created_at: str
    duration_seconds: int
    sample_rate: int
    device_index: int | None = None
    tag: str | None = None

    # Current label; None while the sample is unlabeled.
    label: SampleLabel | None = None

    # --- stored artifacts and the URLs they are exposed under ---
    audio_file: str
    spectrogram_file: str
    audio_url: str
    spectrogram_url: str

    # --- proposed label and review bookkeeping ---
    proposed_label: BinaryLabel | None = None
    proposed_confidence: float | None = None
    review_state: ReviewState = "none"
    # Defaults to False; training treats a falsy value as "not eligible".
    training_approved: bool = False
    label_source: LabelSource = "none"
|
||||
|
||||
|
||||
# Body carrying the label to assign to a sample.
class LabelRequest(BaseModel):
    # One of the SampleLabel values ("snore", "not_snore", "unclear").
    label: SampleLabel
|
||||
|
||||
|
||||
# Options for a training run.
class TrainRequest(BaseModel):
    # Fraction of samples held out for evaluation; validated to 0.1..0.5.
    test_size: float = Field(default=0.25, ge=0.1, le=0.5)
|
||||
|
||||
|
||||
# Summary of a completed training run; mirrors the fields of TrainingResult
# in the training module.
class TrainResponse(BaseModel):
    # Where the trained model was written.
    model_path: str
    # Number of samples that went into the run.
    trained_samples: int
    # Class names known to the model.
    classes: list[str]
    # Hold-out accuracy; None when no evaluation was performed.
    accuracy: float | None = None
|
||||
|
||||
|
||||
# Request to classify one stored sample.
class DetectRequest(BaseModel):
    # Identifier of the sample to classify.
    sample_id: str
    # Optional explicit model file.
    # NOTE(review): presumably a default model is used when None -- confirm in the handler.
    model_path: str | None = None
|
||||
|
||||
|
||||
# Result of classifying one sample.
class DetectResponse(BaseModel):
    # Identifier of the classified sample.
    sample_id: str
    # Label chosen by the model (typed as a plain string, not SampleLabel).
    predicted_label: str
    # Confidence reported for the predicted label.
    confidence: float
    # Model file that produced the prediction.
    model_path: str
|
||||
|
||||
|
||||
# Parameters for starting an overnight gathering session.
class OvernightStartRequest(BaseModel):
    # Total session length in hours; validated to 1..12, defaults to 8.
    duration_hours: int = Field(default=8, ge=1, le=12)
    # Length of each clip; validated to 2..600, defaults to 20.
    clip_duration_seconds: int = Field(default=20, ge=2, le=600)
    # Spacing between captures; validated to 2..600, defaults to 30.
    interval_seconds: int = Field(default=30, ge=2, le=600)
    # Input device to use; None leaves the choice to the capture code.
    device_index: int | None = None
    # Leading part of generated clip tags, at most 80 characters.
    tag_prefix: str = Field(default="overnight", max_length=80)
    # On by default.
    auto_watch: bool = True
|
||||
|
||||
|
||||
# Snapshot of the overnight session state. Everything except ``running`` has
# a default, so a status with no session details is valid.
class OvernightStatusResponse(BaseModel):
    running: bool

    # --- session identity and timing (timestamps are strings) ---
    session_id: str | None = None
    started_at: str | None = None
    planned_end_at: str | None = None
    next_capture_at: str | None = None

    # --- echo of the settings the session was started with ---
    duration_hours: int | None = None
    clip_duration_seconds: int | None = None
    interval_seconds: int | None = None
    auto_watch: bool | None = None

    # --- progress counters and last error message ---
    captured_samples: int = 0
    pending_reviews: int = 0
    last_error: str | None = None
|
||||
|
||||
|
||||
# Generic single-message payload.
class ApiMessage(BaseModel):
    # Human-readable text.
    message: str
|
||||
119
src/snorestopper/storage.py
Normal file
119
src/snorestopper/storage.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from .config import Settings
|
||||
|
||||
|
||||
def ensure_directories(settings: Settings) -> Path:
    """Create the data directory layout and the sample metadata file.

    Makes sure ``data_dir``, ``raw_dir``, ``spectrogram_dir``, ``models_dir``
    and ``metadata_dir`` from ``settings`` exist (parents included), then
    seeds ``metadata_dir / "samples.json"`` with an empty JSON list unless the
    file is already there.

    Returns:
        Path to the ``samples.json`` metadata file.
    """
    required_dirs = [
        settings.data_dir,
        settings.raw_dir,
        settings.spectrogram_dir,
        settings.models_dir,
        settings.metadata_dir,
    ]
    for directory in required_dirs:
        directory.mkdir(parents=True, exist_ok=True)

    samples_index = settings.metadata_dir / "samples.json"
    if samples_index.exists():
        # Never overwrite existing metadata.
        return samples_index

    samples_index.write_text("[]", encoding="utf-8")
    return samples_index
|
||||
|
||||
|
||||
class MetadataStore:
    """JSON-file-backed store for sample metadata records.

    Records are plain dicts kept as a single JSON array in ``metadata_file``.
    Every public method re-reads the file while holding an internal lock, so
    calls made through one instance are serialized. The lock does not guard
    against other instances or other processes touching the same file.
    """

    def __init__(self, metadata_file: Path):
        """Bind to ``metadata_file``, creating it (and its parent) if missing."""
        self.metadata_file = metadata_file
        self._lock = Lock()
        self.metadata_file.parent.mkdir(parents=True, exist_ok=True)
        if not self.metadata_file.exists():
            self.metadata_file.write_text("[]", encoding="utf-8")

    def _read(self) -> list[dict[str, Any]]:
        """Load all records; callers must hold ``self._lock``.

        Deliberately lenient: a missing, empty, unparsable or non-list file
        yields an empty list, and non-dict entries are dropped.
        """
        try:
            raw = self.metadata_file.read_text(encoding="utf-8").strip()
        except FileNotFoundError:
            return []
        if not raw:
            return []

        try:
            payload = json.loads(raw)
        except json.JSONDecodeError:
            return []

        if not isinstance(payload, list):
            return []
        return [item for item in payload if isinstance(item, dict)]

    def _write(self, records: list[dict[str, Any]]) -> None:
        """Persist ``records`` atomically; callers must hold ``self._lock``.

        The payload goes to a sibling temp file that is then moved over the
        real file. Because ``_read`` treats an unparsable file as empty, an
        interrupted in-place write could otherwise wipe every record on the
        next write.
        """
        payload = json.dumps(records, indent=2, ensure_ascii=True)
        tmp_path = self.metadata_file.with_name(self.metadata_file.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            tmp_path.replace(self.metadata_file)
        except BaseException:
            # Best-effort cleanup; the original error is what matters.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def list_samples(self) -> list[dict[str, Any]]:
        """Return all records, newest ``created_at`` first."""
        with self._lock:
            records = self._read()
        return sorted(records, key=lambda item: item.get("created_at", ""), reverse=True)

    def get_sample(self, sample_id: str) -> dict[str, Any] | None:
        """Return the first record whose ``sample_id`` matches, else ``None``."""
        with self._lock:
            records = self._read()
        return next(
            (record for record in records if record.get("sample_id") == sample_id),
            None,
        )

    def add_sample(self, record: dict[str, Any]) -> dict[str, Any]:
        """Append ``record`` to the store and return it."""
        with self._lock:
            records = self._read()
            records.append(record)
            self._write(records)
        return record

    def update_label(self, sample_id: str, label: str) -> dict[str, Any]:
        """Set the ``label`` field of one sample and return the updated record.

        Raises:
            KeyError: No record has the given ``sample_id``.
        """
        return self.update_sample_fields(sample_id, {"label": label})

    def update_sample_fields(
        self,
        sample_id: str,
        updates: dict[str, Any],
    ) -> dict[str, Any]:
        """Merge ``updates`` into one sample and return the updated record.

        Only the first record with a matching ``sample_id`` is changed.

        Raises:
            KeyError: No record has the given ``sample_id``.
        """
        with self._lock:
            records = self._read()
            for record in records:
                if record.get("sample_id") == sample_id:
                    record.update(updates)
                    self._write(records)
                    return record

        raise KeyError(f"Sample '{sample_id}' was not found")

    def list_pending_review_samples(self) -> list[dict[str, Any]]:
        """Return records whose ``review_state`` is ``"pending"``, newest first."""
        with self._lock:
            records = self._read()
        pending = [item for item in records if item.get("review_state") == "pending"]
        return sorted(pending, key=lambda item: item.get("created_at", ""), reverse=True)

    def count_pending_review_samples(self) -> int:
        """Return how many records have ``review_state`` equal to ``"pending"``."""
        with self._lock:
            records = self._read()
        return sum(1 for item in records if item.get("review_state") == "pending")
|
||||
164
src/snorestopper/training.py
Normal file
164
src/snorestopper/training.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
|
||||
from .features import extract_feature_vector, read_audio
|
||||
|
||||
|
||||
def _is_training_eligible(sample: dict[str, Any]) -> bool:
    """Decide whether a sample record may be used for training.

    A sample qualifies when its ``label`` is ``"snore"`` or ``"not_snore"``
    and its ``training_approved`` flag is either truthy or absent/``None``.
    """
    binary_labels = {"snore", "not_snore"}
    if sample.get("label") not in binary_labels:
        return False

    flag = sample.get("training_approved")
    # A missing flag means the record predates review gating; keep accepting
    # those so older samples stay usable.
    return flag is None or bool(flag)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TrainingResult:
    """Immutable summary of one training run."""

    # Location the model was saved to, as a string.
    model_path: str
    # Number of feature rows the run was built from.
    trained_samples: int
    # Class names known to the label encoder.
    classes: list[str]
    # Hold-out accuracy, or None when no evaluation split was made.
    accuracy: float | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PredictionResult:
    """Immutable outcome of classifying a single sample."""

    # Label chosen by the model, decoded back to its string form.
    predicted_label: str
    # Confidence reported for the chosen label.
    confidence: float
|
||||
|
||||
|
||||
def _build_training_matrix(
    samples: list[dict[str, Any]],
    raw_dir: Path,
) -> tuple[np.ndarray, list[str]]:
    """Turn eligible sample records into a feature matrix and label list.

    A record is skipped when it is not training-eligible, lacks a label or
    ``audio_file``, points at a file missing from ``raw_dir``, or decodes to
    empty audio.

    Returns:
        ``(features, labels)``: the per-sample feature vectors stacked with
        ``np.vstack``, and the matching labels as strings.

    Raises:
        ValueError: Fewer than 2 usable samples, or fewer than 2 distinct
            labels among them.
    """
    vectors: list[np.ndarray] = []
    targets: list[str] = []

    for entry in samples:
        if not _is_training_eligible(entry):
            continue

        label, audio_file = entry.get("label"), entry.get("audio_file")
        if not (label and audio_file):
            continue

        clip_path = raw_dir / str(audio_file)
        if not clip_path.exists():
            continue

        waveform, rate = read_audio(clip_path)
        if waveform.size == 0:
            continue

        vectors.append(extract_feature_vector(waveform, rate))
        targets.append(str(label))

    if len(vectors) < 2:
        raise ValueError("At least 2 labeled samples are required for training")
    if len(set(targets)) < 2:
        raise ValueError("Training needs at least 2 distinct labels")

    return np.vstack(vectors), targets
|
||||
|
||||
|
||||
def train_local_model(
    samples: list[dict[str, Any]],
    raw_dir: Path,
    model_path: Path,
    test_size: float = 0.25,
) -> TrainingResult:
    """Train a random-forest classifier on the eligible samples and save it.

    Args:
        samples: Sample metadata records; ineligible ones are filtered out by
            ``_build_training_matrix``.
        raw_dir: Directory containing the audio files named in the records.
        model_path: Destination of the saved model bundle; parent directories
            are created as needed.
        test_size: Fraction of samples held out for the accuracy estimate.

    Returns:
        A ``TrainingResult``. ``trained_samples`` is the total number of
        feature rows (held-out rows included). ``accuracy`` is ``None`` when
        the data set is too small for a stratified hold-out split, in which
        case the model is fit on every row.

    Raises:
        ValueError: Propagated from ``_build_training_matrix`` when there are
            too few samples or labels.
    """
    features, labels = _build_training_matrix(samples, raw_dir)

    encoder = LabelEncoder()
    encoded_labels = encoder.fit_transform(labels)

    class_counts = np.bincount(encoded_labels)
    n_samples = len(encoded_labels)
    n_classes = len(class_counts)

    # A stratified split needs at least one row per class on BOTH sides.
    # scikit-learn sizes the test side as ceil(test_size * n_samples) for a
    # fractional test_size and raises ValueError when either side is smaller
    # than the class count (e.g. 8 samples at test_size=0.1 -> 1 test row for
    # 2 classes). Check that up front and fall back to the no-split path.
    n_test = int(np.ceil(test_size * n_samples))
    n_train = n_samples - n_test
    can_split = (
        n_samples >= 8
        and n_classes >= 2
        and bool(np.all(class_counts >= 2))
        and n_test >= n_classes
        and n_train >= n_classes
    )

    if can_split:
        x_train, x_test, y_train, y_test = train_test_split(
            features,
            encoded_labels,
            test_size=test_size,
            random_state=42,
            stratify=encoded_labels,
        )
    else:
        x_train = features
        y_train = encoded_labels
        x_test = None
        y_test = None

    model = RandomForestClassifier(
        n_estimators=200,
        random_state=42,
        class_weight="balanced",
    )
    model.fit(x_train, y_train)

    accuracy: float | None = None
    if x_test is not None and y_test is not None:
        predicted = model.predict(x_test)
        accuracy = float(accuracy_score(y_test, predicted))

    model_path.parent.mkdir(parents=True, exist_ok=True)
    # The encoder is saved next to the model so predictions can be decoded
    # back to label strings.
    joblib.dump({"model": model, "label_encoder": encoder}, model_path)

    return TrainingResult(
        model_path=str(model_path),
        trained_samples=int(features.shape[0]),
        classes=[str(item) for item in encoder.classes_.tolist()],
        accuracy=accuracy,
    )
|
||||
|
||||
|
||||
def predict_sample(
    audio_file: str,
    raw_dir: Path,
    model_path: Path,
) -> PredictionResult:
    """Classify one stored audio file with a previously saved model bundle.

    Args:
        audio_file: File name of the clip, relative to ``raw_dir``.
        raw_dir: Directory containing the audio files.
        model_path: Path of a bundle written by ``train_local_model`` (a dict
            with ``"model"`` and ``"label_encoder"`` entries).

    Returns:
        The decoded label and its confidence. The confidence is the highest
        class probability when the model offers ``predict_proba``, otherwise
        a fixed ``1.0``.

    Raises:
        FileNotFoundError: The model file or the audio file does not exist.
        ValueError: The model bundle is malformed, or the audio is empty.
    """
    if not model_path.exists():
        raise FileNotFoundError(f"Model file was not found: {model_path}")

    # SECURITY: joblib.load unpickles the file and can execute arbitrary code.
    # Only load model files this application wrote itself.
    # NOTE(review): if model_path can come from a request, restrict it to the
    # models directory before it reaches this function.
    model_package = joblib.load(model_path)
    if not isinstance(model_package, dict):
        raise ValueError(f"Model file has an unexpected format: {model_path}")

    model = model_package.get("model")
    encoder = model_package.get("label_encoder")
    if model is None or encoder is None:
        # Fail with a clear message instead of an AttributeError on None later.
        raise ValueError(
            f"Model file is missing the trained model or label encoder: {model_path}"
        )

    audio_path = raw_dir / audio_file
    if not audio_path.exists():
        raise FileNotFoundError(f"Audio file was not found: {audio_path}")

    audio, sample_rate = read_audio(audio_path)
    if audio.size == 0:
        # Training skips empty clips; reject them here for the same reason.
        raise ValueError(f"Audio file contains no samples: {audio_path}")

    feature_vector = extract_feature_vector(audio, sample_rate).reshape(1, -1)

    if hasattr(model, "predict_proba"):
        probabilities = model.predict_proba(feature_vector)[0]
        best_probability_index = int(np.argmax(probabilities))
        # Probability columns follow model.classes_, so map the column index
        # back to the encoded class value.
        encoded_label = int(model.classes_[best_probability_index])
        confidence = float(probabilities[best_probability_index])
    else:
        encoded_label = int(model.predict(feature_vector)[0])
        confidence = 1.0

    predicted_label = str(encoder.inverse_transform([encoded_label])[0])
    return PredictionResult(predicted_label=predicted_label, confidence=confidence)
|
||||
Reference in New Issue
Block a user