Files
dictia-public/src/services/transcription/connectors/asr_endpoint.py

338 lines
13 KiB
Python

"""
ASR Endpoint connector for custom self-hosted ASR services.
Supports whisper-asr-webservice, WhisperX, and other compatible ASR services
that expose a /asr endpoint.
"""
import logging
import os
import httpx
from typing import Dict, Any, Set, Optional
from ..base import (
BaseTranscriptionConnector,
TranscriptionCapability,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionSegment,
ConnectorSpecifications,
)
from ..exceptions import TranscriptionError, ConfigurationError, ProviderError
from src.config.app_config import ASR_ENABLE_CHUNKING, ASR_MAX_DURATION_SECONDS
# Module-level logger shared by every ASREndpointConnector diagnostic below.
logger = logging.getLogger(__name__)
class ASREndpointConnector(BaseTranscriptionConnector):
    """Connector for custom ASR webservice (whisper-asr-webservice, WhisperX, etc.).

    Audio is POSTed as multipart form data (field ``audio_file``) to
    ``{base_url}/asr`` with transcription options passed as query parameters,
    and the JSON reply is converted into a ``TranscriptionResponse``.
    """

    CAPABILITIES: Set[TranscriptionCapability] = {
        TranscriptionCapability.DIARIZATION,
        TranscriptionCapability.TIMESTAMPS,
        TranscriptionCapability.LANGUAGE_DETECTION,
        TranscriptionCapability.SPEAKER_COUNT_CONTROL,  # Supports min/max speakers
    }
    PROVIDER_NAME = "asr_endpoint"

    # SPECIFICATIONS is set dynamically in __init__ based on ASR_ENABLE_CHUNKING config.
    # Default values here for class-level reference (overridden per-instance).
    SPECIFICATIONS = ConnectorSpecifications(
        max_file_size_bytes=None,
        max_duration_seconds=None,
        handles_chunking_internally=True,
    )

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the ASR Endpoint connector.

        Args:
            config: Configuration dict with keys:
                - base_url: ASR service base URL (required)
                - timeout: Request timeout in seconds (default: 1800)
                - return_speaker_embeddings: Whether to request embeddings (default: False)
                - diarize: Whether to enable diarization by default (default: True)
        """
        super().__init__(config)
        self.base_url = config['base_url'].rstrip('/')
        self._config_timeout = config.get('timeout', 1800)  # 30 minutes default
        self.return_embeddings = config.get('return_speaker_embeddings', False)
        self.default_diarize = config.get('diarize', True)

        # ASR_ENABLE_CHUNKING=true enables app-level chunking for self-hosted ASR
        # services that may crash on long files due to GPU memory exhaustion.
        if ASR_ENABLE_CHUNKING:
            # Recommend chunks at 80% of the max duration as a safety margin.
            recommended_chunk = int(ASR_MAX_DURATION_SECONDS * 0.8)
            self.SPECIFICATIONS = ConnectorSpecifications(
                max_file_size_bytes=None,  # No file size limit
                max_duration_seconds=ASR_MAX_DURATION_SECONDS,
                handles_chunking_internally=False,  # App handles chunking
                recommended_chunk_seconds=recommended_chunk,
            )
            logger.info(
                f"ASR chunking enabled: max_duration={ASR_MAX_DURATION_SECONDS}s, "
                f"recommended_chunk={recommended_chunk}s"
            )
        else:
            # Default behavior: ASR service handles everything internally.
            self.SPECIFICATIONS = ConnectorSpecifications(
                max_file_size_bytes=None,
                max_duration_seconds=None,
                handles_chunking_internally=True,
            )

        if self.return_embeddings:
            # `|` builds a new set bound on the instance, so the class-level
            # CAPABILITIES shared by other instances is left untouched.
            self.CAPABILITIES = self.CAPABILITIES | {TranscriptionCapability.SPEAKER_EMBEDDINGS}

    @property
    def timeout(self) -> int:
        """Return the ASR request timeout in seconds.

        Resolved fresh on every access so runtime changes are respected, in
        priority order: the ``ASR_TIMEOUT`` / ``asr_timeout_seconds``
        environment variables, the ``asr_timeout_seconds`` system setting
        (Admin UI), then the ``timeout`` value given at construction.
        """
        env_timeout = os.environ.get('ASR_TIMEOUT') or os.environ.get('asr_timeout_seconds')
        if env_timeout:
            try:
                return int(env_timeout)
            except (ValueError, TypeError):
                logger.warning(f"Ignoring non-integer ASR timeout from environment: {env_timeout!r}")

        try:
            # Imported lazily; the DB may be unavailable outside an app context.
            from src.models import SystemSetting
            db_timeout = SystemSetting.get_setting('asr_timeout_seconds', None)
            if db_timeout is not None:
                return int(db_timeout)
        except Exception as exc:
            # Best-effort lookup: fall through to the configured value.
            logger.debug(f"Could not read asr_timeout_seconds setting: {exc}")

        return self._config_timeout

    def _validate_config(self) -> None:
        """Validate required configuration.

        Raises:
            ConfigurationError: if ``base_url`` is missing or empty.
        """
        if not self.config.get('base_url'):
            raise ConfigurationError("base_url is required for ASR endpoint connector")

    def _build_asr_params(self, request: TranscriptionRequest) -> Dict[str, Any]:
        """Build the query parameters for the ``/asr`` call from *request*."""
        params: Dict[str, Any] = {
            'encode': True,
            'task': 'transcribe',
            'output': 'json',
        }
        if request.language:
            params['language'] = request.language
            logger.info(f"Using transcription language: {request.language}")

        # A per-request diarize flag overrides the connector default.
        should_diarize = request.diarize if request.diarize is not None else self.default_diarize
        # Send both parameter names for compatibility:
        # - 'diarize' is used by whisper-asr-webservice
        # - 'enable_diarization' is used by WhisperX
        params['diarize'] = should_diarize
        params['enable_diarization'] = should_diarize
        if should_diarize and self.return_embeddings:
            params['return_speaker_embeddings'] = True

        if request.min_speakers:
            params['min_speakers'] = request.min_speakers
        if request.max_speakers:
            params['max_speakers'] = request.max_speakers
        if request.prompt:
            params['initial_prompt'] = request.prompt
        if request.hotwords:
            params['hotwords'] = request.hotwords
        return params

    def _decode_asr_json(self, response: httpx.Response) -> Dict[str, Any]:
        """Decode the ASR reply body as a JSON object.

        Raises:
            ProviderError: if the body is not JSON (e.g. an HTML error page
                from a proxy) or is JSON but not an object.
        """
        try:
            data = response.json()
        except ValueError as json_err:
            if response.text.strip().startswith('<'):
                logger.error(f"ASR returned HTML error page (status {response.status_code})")
                raise ProviderError(
                    "ASR service returned HTML error page",
                    provider=self.PROVIDER_NAME,
                    status_code=response.status_code
                ) from json_err
            raise ProviderError(
                f"ASR service returned invalid response: {json_err}",
                provider=self.PROVIDER_NAME,
                status_code=response.status_code
            ) from json_err

        if not isinstance(data, dict):
            raise ProviderError(
                f"ASR service returned invalid response: expected a JSON object, got {type(data).__name__}",
                provider=self.PROVIDER_NAME,
                status_code=response.status_code
            )
        return data

    def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
        """
        Transcribe audio using ASR webservice.

        Args:
            request: Standardized transcription request

        Returns:
            TranscriptionResponse with segments and speaker information

        Raises:
            ProviderError: the service answered with an HTTP error status or
                with a body that is not a JSON object.
            TranscriptionError: the request timed out or failed otherwise.
        """
        # Resolve once: the property re-reads env/DB on every access, and the
        # request, log lines and timeout error must all agree on one value.
        timeout_seconds = self.timeout
        url = f"{self.base_url}/asr"
        try:
            params = self._build_asr_params(request)
            content_type = request.mime_type or 'application/octet-stream'
            files = {
                'audio_file': (request.filename, request.audio_file, content_type)
            }
            # Generous read AND write timeouts: uploading large files takes time too.
            timeout = httpx.Timeout(
                None,
                connect=60.0,
                read=float(timeout_seconds),
                write=float(timeout_seconds),
                pool=None
            )

            logger.info(f"Sending ASR request to {url} with params: {params} (timeout: {timeout_seconds}s)")
            with httpx.Client() as client:
                response = client.post(url, params=params, files=files, timeout=timeout)
            logger.info(f"ASR request completed with status: {response.status_code}")
            response.raise_for_status()

            data = self._decode_asr_json(response)
            return self._parse_response(data)

        except ProviderError:
            # Already carries provider/status_code; do not re-wrap it below.
            raise
        except httpx.HTTPStatusError as e:
            status = e.response.status_code
            logger.error(f"ASR request failed with status {status}: {e.response.text[:500]}")
            raise ProviderError(
                f"ASR request failed with status {status}",
                provider=self.PROVIDER_NAME,
                status_code=status
            ) from e
        except httpx.TimeoutException as e:
            logger.error(f"ASR request timed out after {timeout_seconds}s")
            raise TranscriptionError(f"ASR request timed out after {timeout_seconds}s") from e
        except Exception as e:
            error_msg = str(e)
            logger.error(f"ASR transcription failed: {error_msg}")
            raise TranscriptionError(f"ASR transcription failed: {error_msg}") from e

    def _parse_response(self, data: Dict[str, Any]) -> TranscriptionResponse:
        """
        Parse ASR webservice response into standardized format.

        The ASR response contains:
        - text: Full transcription text
        - language: Detected language
        - segments: Array of segments with speaker, text, start, end
        - speaker_embeddings: Optional speaker embeddings (WhisperX only)

        Segments without a speaker inherit the previous segment's speaker,
        or 'UNKNOWN_SPEAKER' if none has been seen yet. Non-object segments
        are skipped.
        """
        segments = []
        speakers = set()
        full_text_parts = []
        last_speaker = None

        logger.info(f"ASR response keys: {list(data.keys())}")

        raw_segments = data.get('segments')
        if isinstance(raw_segments, list):
            logger.info(f"Number of segments: {len(raw_segments)}")
            for seg in raw_segments:
                if not isinstance(seg, dict):
                    logger.warning(f"Skipping malformed ASR segment: {seg!r}")
                    continue

                speaker = seg.get('speaker')
                if speaker is None:
                    # Carry the previous speaker forward across unlabeled segments.
                    speaker = last_speaker if last_speaker is not None else 'UNKNOWN_SPEAKER'
                else:
                    last_speaker = speaker

                # `or ''` also covers an explicit JSON null for text.
                text = (seg.get('text') or '').strip()
                speakers.add(speaker)
                full_text_parts.append(f"[{speaker}]: {text}")
                segments.append(TranscriptionSegment(
                    text=text,
                    speaker=speaker,
                    start_time=seg.get('start'),
                    end_time=seg.get('end')
                ))

        # Prefer the service's own full text; fall back to the speaker-labelled
        # segment text when it is missing or blank.
        raw_text = data.get('text')
        if isinstance(raw_text, str) and raw_text.strip():
            full_text = raw_text
        elif full_text_parts:
            full_text = '\n'.join(full_text_parts)
        else:
            full_text = ''

        speaker_embeddings = data.get('speaker_embeddings')
        if speaker_embeddings is not None and not isinstance(speaker_embeddings, dict):
            logger.warning(
                f"Ignoring speaker_embeddings of unexpected type {type(speaker_embeddings).__name__}"
            )
            speaker_embeddings = None
        if speaker_embeddings:
            logger.info(f"Received speaker embeddings for speakers: {list(speaker_embeddings.keys())}")

        sorted_speakers = sorted(speakers)
        logger.info(f"Parsed {len(segments)} segments with {len(speakers)} unique speakers: {sorted_speakers}")

        return TranscriptionResponse(
            text=full_text,
            segments=segments,
            speakers=sorted_speakers,
            speaker_embeddings=speaker_embeddings,
            language=data.get('language'),
            provider=self.PROVIDER_NAME,
            model="asr-endpoint",
            raw_response=data
        )

    def health_check(self) -> bool:
        """Return True if the ASR endpoint answers with any non-5xx status.

        Tries ``/health`` first, then ``/``, with a 10 second timeout; any
        failure is treated as unhealthy rather than raised.
        """
        try:
            with httpx.Client(timeout=10.0) as client:
                for endpoint in ['/health', '/']:
                    try:
                        response = client.get(f"{self.base_url}{endpoint}")
                    except Exception as exc:
                        logger.debug(f"ASR health probe {endpoint} failed: {exc}")
                        continue
                    if response.status_code < 500:
                        return True
            return False
        except Exception:
            return False

    @classmethod
    def get_config_schema(cls) -> Dict[str, Any]:
        """Return JSON schema for configuration."""
        return {
            "type": "object",
            "required": ["base_url"],
            "properties": {
                "base_url": {
                    "type": "string",
                    "description": "ASR service base URL (e.g., http://whisper-asr:9000)"
                },
                "timeout": {
                    "type": "integer",
                    "default": 1800,
                    "description": "Request timeout in seconds"
                },
                "diarize": {
                    "type": "boolean",
                    "default": True,
                    "description": "Enable speaker diarization by default"
                },
                "return_speaker_embeddings": {
                    "type": "boolean",
                    "default": False,
                    "description": "Request speaker embeddings (WhisperX only)"
                }
            }
        }