"""
|
|
Speech-to-text transcription for recorded calls using OpenAI Whisper.
|
|
|
|
Audio is downloaded from GCS then sent to the Whisper API. Falls back to
|
|
returning None on any failure so the intelligence pipeline can still run.
|
|
"""
|
|
import asyncio
|
|
import tempfile
|
|
import os
|
|
from typing import Optional
|
|
from app.internal.logger import logger
|
|
from app.internal import firestore as fstore
|
|
|
|
# Whisper treats `prompt` as preceding transcript text, not instructions.
# Writing it as actual radio speech primes the vocabulary toward P25 codes
# and phrasing before the model hears the audio.
# NOTE: "In route." is intentional alongside "En route." — it biases the
# decoder toward both spellings heard in real radio traffic.
_WHISPER_PROMPT = (
    "10-4. 10-23. 10-20. 10-97. 10-8. 10-7. 10-34. 10-50. 10-52. "
    "Post 4, I'm out. Post 3. En route. On scene. In route. "
    "Copy. Negative. Stand by. Be advised. Go ahead. "
    "Units responding. Dispatch. Talkgroup. "
    "Engine. Ladder. Medic. Rescue. Car. Unit. "
    "MVA. MVC. Structure fire. Working fire."
)
|
|
|
|
|
|
async def transcribe_call(
    call_id: str,
    gcs_uri: str,
    talkgroup_name: Optional[str] = None,
) -> tuple[Optional[str], list[dict]]:
    """
    Transcribe audio at the given GCS URI and store the result in Firestore.

    Returns:
        (transcript, segments) — segments is a list of {start, end, text} dicts,
        one per detected transmission. Empty list if transcription failed.
    """
    # Guard clause: only gs:// URIs can be transcribed.
    if not (gcs_uri and gcs_uri.startswith("gs://")):
        return None, []

    # The download + Whisper call is blocking; run it off the event loop.
    try:
        text, parts = await asyncio.to_thread(_sync_transcribe, gcs_uri)
    except Exception as e:
        logger.warning(f"Transcription failed for call {call_id}: {e}")
        return None, []

    if text:
        payload: dict = {"transcript": text}
        if parts:
            payload["segments"] = parts
        # Persistence is best-effort — a Firestore hiccup must not drop the
        # transcript from the return value.
        try:
            await fstore.doc_set("calls", call_id, payload)
            logger.info(
                f"Transcript saved for call {call_id} "
                f"({len(text)} chars, {len(parts)} segment(s))"
            )
        except Exception as e:
            logger.warning(f"Could not save transcript for {call_id}: {e}")

    return text, parts
|
|
|
|
|
|
def _sync_transcribe(gcs_uri: str) -> tuple[Optional[str], list[dict]]:
    """
    Download audio from GCS and transcribe it with OpenAI Whisper.

    Args:
        gcs_uri: Fully-qualified ``gs://bucket/path`` URI of the recording.

    Returns:
        (transcript, segments) — transcript is None when transcription is
        disabled or produced no text; segments is a list of
        {start, end, text} dicts (empty when unavailable).
    """
    # Imported lazily so the module loads even when these optional
    # dependencies are absent or unconfigured.
    from google.cloud import storage as gcs
    from google.oauth2 import service_account
    from openai import OpenAI

    from app.config import settings

    if not settings.openai_api_key:
        logger.warning("OPENAI_API_KEY not set — transcription disabled.")
        # BUG FIX: was `return None`, which broke the caller's
        # `transcript, segments = ...` unpack with a TypeError.
        return None, []

    # Split "gs://bucket/path/to/blob" into bucket and blob path.
    without_scheme = gcs_uri[len("gs://"):]
    bucket_name, blob_path = without_scheme.split("/", 1)

    # Prefer explicit service-account credentials when configured; fall back
    # to application-default credentials otherwise.
    if settings.gcp_credentials_path:
        creds = service_account.Credentials.from_service_account_file(
            settings.gcp_credentials_path,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
        gcs_client = gcs.Client(credentials=creds)
    else:
        gcs_client = gcs.Client()

    bucket = gcs_client.bucket(bucket_name)
    blob = bucket.blob(blob_path)

    # Keep the original extension so Whisper can sniff the container format;
    # delete=False because the file is reopened by path after the `with`.
    suffix = os.path.splitext(blob_path)[1] or ".mp3"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp_path = tmp.name

    try:
        blob.download_to_filename(tmp_path)

        openai_client = OpenAI(api_key=settings.openai_api_key)
        with open(tmp_path, "rb") as f:
            response = openai_client.audio.transcriptions.create(
                model="whisper-1",
                file=f,
                language="en",
                prompt=_WHISPER_PROMPT,
                # verbose_json includes per-segment timestamps.
                response_format="verbose_json",
            )
        text = response.text.strip() or None
        segments = [
            {"start": round(s.start, 2), "end": round(s.end, 2), "text": s.text.strip()}
            for s in (response.segments or [])
            if s.text.strip()
        ]
        return text, segments
    finally:
        # Always remove the scratch file, even when download/transcription
        # raised; ignore a file that was never created or already removed.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
|