Start to learn vocab from talkgroups to improve accuracy of STT
This commit is contained in:
@@ -0,0 +1,426 @@
|
||||
"""
|
||||
Per-system vocabulary learning for STT accuracy improvement.
|
||||
|
||||
Three mechanisms:
|
||||
1. Bootstrap — one-shot GPT-4o call generates local knowledge at system setup:
|
||||
agencies + abbreviations, unit naming, streets, acronyms.
|
||||
2. Correction — diffs admin transcript edits, extracts corrected tokens → vocabulary.
|
||||
3. Induction — background loop samples N tokens of transcripts per system,
|
||||
asks GPT-4o-mini to propose new terms → queued as pending for review.
|
||||
|
||||
Firestore schema additions on system documents:
|
||||
vocabulary: list[str] — approved terms; injected into Whisper + GPT prompts
|
||||
vocabulary_pending: list[dict] — induction suggestions awaiting admin review
|
||||
each: {term, source, added_at}
|
||||
vocabulary_bootstrapped: bool — bootstrap has been run at least once
|
||||
"""
|
||||
import asyncio
|
||||
import difflib
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from app.internal.logger import logger
|
||||
from app.internal import firestore as fstore
|
||||
from app.config import settings
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Prompt templates
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
_BOOTSTRAP_PROMPT = """\
|
||||
You are building a radio vocabulary dictionary to improve speech-to-text accuracy for a P25 \
|
||||
public-safety radio monitoring system in a specific area. The STT model has no local knowledge, \
|
||||
so common terms like "YVAC" get transcribed as "why vac", "5-baker" as "5 acre", etc.
|
||||
|
||||
System name: {system_name}
|
||||
System type: {system_type}
|
||||
Area context: {area_hint}
|
||||
|
||||
Return ONLY a JSON object:
|
||||
{{"vocabulary": [list of strings]}}
|
||||
|
||||
Include terms you are confident about for this area:
|
||||
- Agency names and their radio abbreviations (e.g. "YVAC" = Yorktown Volunteer Ambulance Corps)
|
||||
- Unit ID examples using the local naming convention (e.g. "5-baker", "5-charlie", "1-david";
|
||||
many departments use APCO phonetics: adam, baker, charles, david, edward, frank, george,
|
||||
henry, ida, john, king, lincoln, mary, nora, ocean, paul, queen, robert, sam, tom, union,
|
||||
victor, william, x-ray, young, zebra)
|
||||
- Major routes, roads, and key intersections
|
||||
- Local landmarks and geographic references dispatchers use
|
||||
- Agency-specific codes that differ from standard APCO
|
||||
|
||||
Return a flat list of strings — abbreviations, proper names, unit IDs, street names.
|
||||
Do NOT include common English words. Max 80 terms. Only include what you are confident is \
|
||||
accurate for this specific area; return fewer terms rather than guessing."""
|
||||
|
||||
_INDUCTION_PROMPT = """\
|
||||
You are analyzing P25 emergency radio transcripts to find vocabulary terms that should be \
|
||||
added to improve future speech-to-text accuracy for this system.
|
||||
|
||||
System: {system_name}
|
||||
Existing approved vocabulary (do not re-propose these): {existing_vocab}
|
||||
|
||||
Sampled transcripts:
|
||||
{transcript_block}
|
||||
|
||||
Find terms that are LIKELY STT errors or local terms missing from the vocabulary:
|
||||
- Unit IDs that appear garbled (e.g. "5 acre" → "5-baker")
|
||||
- Agency acronyms spelled out phonetically (e.g. "why vac" → "YVAC")
|
||||
- Street names or locations that look misspelled or oddly transcribed
|
||||
- Callsigns or local codes not yet in the vocabulary
|
||||
|
||||
Return ONLY a JSON object:
|
||||
{{"new_terms": ["term1", "term2", ...]}}
|
||||
|
||||
Only include high-confidence additions not already in existing vocabulary.
|
||||
Return {{"new_terms": []}} if nothing new is found."""
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Public API
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def bootstrap_system_vocabulary(system_id: str) -> list[str]:
    """
    One-shot GPT-4o bootstrap: generate local-knowledge vocabulary for a system.

    Merges generated terms into system.vocabulary and sets
    vocabulary_bootstrapped=True. Returns the list of newly generated terms.
    """
    doc = await fstore.doc_get("systems", system_id)
    if not doc:
        logger.warning(f"Vocabulary bootstrap: system {system_id} not found")
        return []

    name = doc.get("name", "Unknown")
    sys_type = doc.get("type", "P25")

    # Derive an area hint from up to 8 configured talkgroup names.
    talkgroup_cfg = doc.get("config", {}).get("talkgroups", [])
    labels = [tg.get("name", "") for tg in talkgroup_cfg if tg.get("name")][:8]
    if labels:
        area_hint = f"Talkgroups include: {', '.join(labels)}"
    else:
        area_hint = "Unknown area"

    # The OpenAI call is blocking; run it off the event loop.
    generated = await asyncio.to_thread(_sync_bootstrap, name, sys_type, area_hint)
    if not generated:
        return []

    # Case-insensitive merge against the already-approved vocabulary.
    current = doc.get("vocabulary") or []
    seen = {t.lower() for t in current}
    fresh = [t for t in generated if t.lower() not in seen]
    merged = list(dict.fromkeys(current + fresh))

    await fstore.doc_set("systems", system_id, {
        "vocabulary": merged,
        "vocabulary_bootstrapped": True,
    })
    logger.info(
        f"Vocabulary bootstrap: {len(fresh)} term(s) generated for system {system_id} "
        f"({name})"
    )
    return fresh
|
||||
|
||||
|
||||
async def learn_from_correction(system_id: str, original: str, corrected: str) -> None:
    """
    Diff original and corrected transcripts; append new tokens to the
    approved vocabulary. Called automatically when an admin saves a
    transcript correction.
    """
    # Nothing to learn without both texts and a system to attach terms to.
    if not (system_id and original and corrected):
        return

    candidates = _diff_new_terms(original, corrected)
    if not candidates:
        return

    doc = await fstore.doc_get("systems", system_id)
    if not doc:
        return

    vocab = doc.get("vocabulary") or []
    known = {t.lower() for t in vocab}
    fresh = [t for t in candidates if t.lower() not in known]
    if not fresh:
        return

    await fstore.doc_set("systems", system_id, {
        "vocabulary": list(dict.fromkeys(vocab + fresh)),
    })
    logger.info(
        f"Vocabulary: learned {len(fresh)} term(s) from correction on system {system_id}: "
        f"{fresh}"
    )
|
||||
|
||||
|
||||
async def approve_pending_term(system_id: str, term: str) -> None:
    """Move a pending term into the approved vocabulary."""
    doc = await fstore.doc_get("systems", system_id)
    if not doc:
        return
    # Drop the term from the pending queue (exact-case match on "term").
    remaining = [
        entry for entry in (doc.get("vocabulary_pending") or [])
        if entry["term"] != term
    ]
    approved = doc.get("vocabulary") or []
    # Case-insensitive duplicate check before appending.
    if term.lower() not in {t.lower() for t in approved}:
        approved = list(dict.fromkeys(approved + [term]))
    await fstore.doc_set("systems", system_id, {
        "vocabulary": approved,
        "vocabulary_pending": remaining,
    })
|
||||
|
||||
|
||||
async def dismiss_pending_term(system_id: str, term: str) -> None:
    """Remove a pending term without adding it to vocabulary."""
    doc = await fstore.doc_get("systems", system_id)
    if not doc:
        return
    kept = [
        entry for entry in (doc.get("vocabulary_pending") or [])
        if entry["term"] != term
    ]
    await fstore.doc_set("systems", system_id, {"vocabulary_pending": kept})
|
||||
|
||||
|
||||
async def add_term(system_id: str, term: str) -> None:
    """
    Manually add a term to the approved vocabulary.

    The term is whitespace-stripped *before* the duplicate check so that
    " YVAC " and "YVAC" are the same term; blank input is a no-op.
    """
    # Fix: the original lowercased the unstripped term for the duplicate
    # check but stored the stripped one, so padded input could slip a
    # duplicate — or a whitespace-only term an empty string — into the list.
    term = term.strip()
    if not term:
        return
    system_doc = await fstore.doc_get("systems", system_id)
    if not system_doc:
        return
    vocab = system_doc.get("vocabulary") or []
    # Case-insensitive duplicate check on the stripped term.
    if term.lower() not in {t.lower() for t in vocab}:
        vocab = list(dict.fromkeys(vocab + [term]))
    await fstore.doc_set("systems", system_id, {"vocabulary": vocab})
|
||||
|
||||
|
||||
async def remove_term(system_id: str, term: str) -> None:
    """Remove a term from the approved vocabulary (case-insensitive match)."""
    doc = await fstore.doc_get("systems", system_id)
    if not doc:
        return
    target = term.lower()
    kept = [t for t in (doc.get("vocabulary") or []) if t.lower() != target]
    await fstore.doc_set("systems", system_id, {"vocabulary": kept})
|
||||
|
||||
|
||||
async def get_vocabulary(system_id: str) -> dict:
    """Return vocabulary, pending terms, and bootstrap flag for a system."""
    doc = await fstore.doc_get("systems", system_id)
    if doc:
        return {
            "vocabulary": doc.get("vocabulary") or [],
            "vocabulary_pending": doc.get("vocabulary_pending") or [],
            "vocabulary_bootstrapped": doc.get("vocabulary_bootstrapped", False),
        }
    # Unknown system: empty-but-well-formed shape so callers never branch.
    return {"vocabulary": [], "vocabulary_pending": [], "vocabulary_bootstrapped": False}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Prompt-injection helpers (called by transcription.py and intelligence.py)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def build_whisper_vocab_prompt(vocabulary: list[str]) -> str:
    """
    Format vocabulary for Whisper prompt injection.

    Whisper's prompt field acts as a context prior with a ~224-token limit.
    The base _WHISPER_PROMPT uses ~70 tokens; we budget ~150 tokens
    (≈550 chars) here, taking terms in order until the budget is exhausted.
    """
    if not vocabulary:
        return ""
    budget = 550
    selected: list[str] = []
    consumed = 0
    for entry in vocabulary:
        price = len(entry) + 2  # account for the ", " separator
        if consumed + price > budget:
            break
        selected.append(entry)
        consumed += price
    if not selected:
        return ""
    return ", ".join(selected) + ". "
|
||||
|
||||
|
||||
def build_gpt_vocab_block(vocabulary: list[str]) -> str:
    """Format vocabulary for injection into GPT extraction prompts."""
    if vocabulary:
        return f"Known local terms: {', '.join(vocabulary)}\n"
    return ""
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Background induction loop
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def vocabulary_induction_loop() -> None:
    """Background loop: periodically run vocabulary induction over all systems."""
    sleep_seconds = settings.vocabulary_induction_interval_hours * 3600
    logger.info(
        f"Vocabulary induction loop started — "
        f"interval: {settings.vocabulary_induction_interval_hours}h, "
        f"sample budget: {settings.vocabulary_induction_sample_tokens} tokens"
    )
    while True:
        # Sleep first: no pass runs at startup, only after a full interval.
        await asyncio.sleep(sleep_seconds)
        try:
            await _run_induction_pass()
        except Exception as e:
            # Keep the loop alive; a failed pass is retried next interval.
            logger.error(f"Vocabulary induction pass failed: {e}")
|
||||
|
||||
|
||||
async def _run_induction_pass() -> None:
    """Run one induction pass over every known system, isolating failures."""
    systems = await fstore.collection_list("systems")
    if not systems:
        return
    logger.info(f"Vocabulary induction: processing {len(systems)} system(s)")
    for entry in systems:
        sid = entry.get("system_id")
        if not sid:
            continue
        try:
            await _induct_system(sid, entry)
        except Exception as e:
            # One bad system must not abort the rest of the pass.
            logger.warning(f"Induction failed for system {sid}: {e}")
|
||||
|
||||
|
||||
async def _induct_system(system_id: str, system_doc: dict) -> None:
    """Sample random transcripts for a system and propose new vocabulary."""
    system_name = system_doc.get("name", "Unknown")
    approved: list[str] = system_doc.get("vocabulary") or []

    # Only ended calls carry complete transcripts worth sampling.
    calls = await fstore.collection_list("calls", system_id=system_id, status="ended")
    if not calls:
        return

    # Random sample up to the token budget (4 chars ≈ 1 token). The budget
    # check compares accumulated formatted text against the raw candidate.
    random.shuffle(calls)
    budget = settings.vocabulary_induction_sample_tokens * 4
    block = ""
    count = 0
    for call in calls:
        text = call.get("transcript_corrected") or call.get("transcript") or ""
        if not text:
            continue
        if len(block) + len(text) > budget:
            break
        label = call.get("talkgroup_name") or f"TGID {call.get('talkgroup_id', '?')}"
        block += f"[{label}] {text}\n"
        count += 1

    # Too little signal to learn from yet.
    if count < 3:
        return

    # Blocking OpenAI call — run off the event loop.
    proposed = await asyncio.to_thread(_sync_induct, system_name, approved, block)
    if not proposed:
        return

    stamp = datetime.now(timezone.utc).isoformat()
    pending: list[dict] = system_doc.get("vocabulary_pending") or []
    pending_seen = {p["term"].lower() for p in pending}
    approved_seen = {t.lower() for t in approved}

    # Queue only terms not already approved or already pending.
    queued = []
    for candidate in proposed:
        low = candidate.lower()
        if low not in approved_seen and low not in pending_seen:
            queued.append({"term": candidate, "source": "induction", "added_at": stamp})
    if not queued:
        return

    await fstore.doc_set("systems", system_id, {
        "vocabulary_pending": pending + queued,
    })
    logger.info(
        f"Vocabulary induction: {len(queued)} new term(s) proposed for "
        f"system {system_id} ({system_name}): {[p['term'] for p in queued]}"
    )
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Internal sync helpers
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
_STOP_WORDS = {
|
||||
"the", "and", "for", "are", "was", "were", "this", "that", "with",
|
||||
"have", "has", "had", "but", "not", "from", "they", "will", "what",
|
||||
"can", "all", "been", "one", "two", "three", "four", "five", "six",
|
||||
"you", "out", "who", "get", "her", "him", "his", "its", "our", "my",
|
||||
"via", "per", "any", "now", "got", "she", "let", "did", "may", "yes",
|
||||
"sir", "say", "see", "too", "off", "how", "put", "set", "try", "back",
|
||||
"just", "like", "into", "than", "them", "then", "some", "also", "onto",
|
||||
"went", "over", "copy", "okay", "unit", "post", "road", "lane", "going",
|
||||
"being", "doing", "there", "their", "about", "would", "could", "should",
|
||||
"route", "north", "south", "east", "west", "avenue", "street", "drive",
|
||||
}
|
||||
|
||||
|
||||
def _diff_new_terms(original: str, corrected: str) -> list[str]:
|
||||
"""
|
||||
Token-level diff: find tokens in `corrected` that replaced or were inserted
|
||||
relative to `original`. These are the admin's intended spellings — good
|
||||
candidates for vocabulary.
|
||||
"""
|
||||
orig_tokens = original.split()
|
||||
corr_tokens = corrected.split()
|
||||
|
||||
matcher = difflib.SequenceMatcher(None,
|
||||
[t.lower() for t in orig_tokens],
|
||||
[t.lower() for t in corr_tokens],
|
||||
)
|
||||
new_terms: list[str] = []
|
||||
for tag, _i1, _i2, j1, j2 in matcher.get_opcodes():
|
||||
if tag in ("insert", "replace"):
|
||||
for tok in corr_tokens[j1:j2]:
|
||||
clean = tok.strip(".,!?;:()'\"").strip("-")
|
||||
if len(clean) >= 3 and clean.lower() not in _STOP_WORDS:
|
||||
new_terms.append(clean)
|
||||
|
||||
return list(dict.fromkeys(new_terms))
|
||||
|
||||
|
||||
def _sync_bootstrap(system_name: str, system_type: str, area_hint: str) -> list[str]:
    """Blocking GPT-4o call that generates bootstrap vocabulary terms."""
    # Imported lazily so the module loads without the openai package.
    from app.config import settings as cfg
    from openai import OpenAI

    # No API key configured: bootstrap is silently unavailable.
    if not cfg.openai_api_key:
        return []

    rendered = _BOOTSTRAP_PROMPT.format(
        system_name=system_name,
        system_type=system_type,
        area_hint=area_hint,
    )
    try:
        api = OpenAI(api_key=cfg.openai_api_key)
        result = api.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": rendered}],
            response_format={"type": "json_object"},
        )
        payload = json.loads(result.choices[0].message.content)
        raw_terms = payload.get("vocabulary") or []
        # Coerce, strip, and drop blanks.
        cleaned = [str(t).strip() for t in raw_terms]
        return [t for t in cleaned if t]
    except Exception as e:
        # Best effort: a failed call means no bootstrap terms, not a crash.
        logger.warning(f"Vocabulary bootstrap GPT call failed: {e}")
        return []
|
||||
|
||||
|
||||
def _sync_induct(
    system_name: str, existing_vocab: list[str], transcript_block: str
) -> list[str]:
    """Blocking GPT-4o-mini call that proposes new vocabulary from transcripts."""
    # Imported lazily so the module loads without the openai package.
    from app.config import settings as cfg
    from openai import OpenAI

    # No API key configured: induction is silently unavailable.
    if not cfg.openai_api_key:
        return []

    # Cap prompt size: at most 80 existing terms and 8000 transcript chars.
    if existing_vocab:
        vocab_str = ", ".join(existing_vocab[:80])
    else:
        vocab_str = "(none yet)"
    rendered = _INDUCTION_PROMPT.format(
        system_name=system_name,
        existing_vocab=vocab_str,
        transcript_block=transcript_block[:8000],
    )
    try:
        api = OpenAI(api_key=cfg.openai_api_key)
        result = api.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": rendered}],
            response_format={"type": "json_object"},
        )
        payload = json.loads(result.choices[0].message.content)
        proposals = payload.get("new_terms") or []
        # Coerce, strip, and drop blanks.
        cleaned = [str(t).strip() for t in proposals]
        return [t for t in cleaned if t]
    except Exception as e:
        # Best effort: a failed call means no proposals, not a crash.
        logger.warning(f"Vocabulary induction GPT call failed: {e}")
        return []
|
||||
Reference in New Issue
Block a user