""" ASR Endpoint connector for custom self-hosted ASR services. Supports whisper-asr-webservice, WhisperX, and other compatible ASR services that expose a /asr endpoint. """ import logging import os import httpx from typing import Dict, Any, Set, Optional from ..base import ( BaseTranscriptionConnector, TranscriptionCapability, TranscriptionRequest, TranscriptionResponse, TranscriptionSegment, ConnectorSpecifications, ) from ..exceptions import TranscriptionError, ConfigurationError, ProviderError from src.config.app_config import ASR_ENABLE_CHUNKING, ASR_MAX_DURATION_SECONDS logger = logging.getLogger(__name__) class ASREndpointConnector(BaseTranscriptionConnector): """Connector for custom ASR webservice (whisper-asr-webservice, WhisperX, etc.).""" CAPABILITIES: Set[TranscriptionCapability] = { TranscriptionCapability.DIARIZATION, TranscriptionCapability.TIMESTAMPS, TranscriptionCapability.LANGUAGE_DETECTION, TranscriptionCapability.SPEAKER_COUNT_CONTROL, # Supports min/max speakers } PROVIDER_NAME = "asr_endpoint" # SPECIFICATIONS is set dynamically in __init__ based on ASR_ENABLE_CHUNKING config # Default values here for class-level reference (overridden per-instance) SPECIFICATIONS = ConnectorSpecifications( max_file_size_bytes=None, max_duration_seconds=None, handles_chunking_internally=True, ) def __init__(self, config: Dict[str, Any]): """ Initialize the ASR Endpoint connector. Args: config: Configuration dict with keys: - base_url: ASR service base URL (required) - timeout: Request timeout in seconds (default: 1800) - return_speaker_embeddings: Whether to request embeddings (default: False) - diarize: Whether to enable diarization by default (default: True) """ super().__init__(config) self.base_url = config['base_url'].rstrip('/') self._config_timeout = config.get('timeout', 1800) # 30 minutes default self.return_embeddings = config.get('return_speaker_embeddings', False) self.default_diarize = config.get('diarize', True) # Configure chunking behavior based on environment variables # ASR_ENABLE_CHUNKING=true enables app-level chunking for self-hosted ASR services # that may crash on long files due to GPU memory exhaustion if ASR_ENABLE_CHUNKING: # Calculate recommended chunk size (80% of max for safety margin) recommended_chunk = int(ASR_MAX_DURATION_SECONDS * 0.8) self.SPECIFICATIONS = ConnectorSpecifications( max_file_size_bytes=None, # No file size limit max_duration_seconds=ASR_MAX_DURATION_SECONDS, handles_chunking_internally=False, # App handles chunking recommended_chunk_seconds=recommended_chunk, ) logger.info( f"ASR chunking enabled: max_duration={ASR_MAX_DURATION_SECONDS}s, " f"recommended_chunk={recommended_chunk}s" ) else: # Default behavior: ASR service handles everything internally self.SPECIFICATIONS = ConnectorSpecifications( max_file_size_bytes=None, max_duration_seconds=None, handles_chunking_internally=True, ) # Add speaker embeddings capability if enabled if self.return_embeddings: self.CAPABILITIES = self.CAPABILITIES | {TranscriptionCapability.SPEAKER_EMBEDDINGS} @property def timeout(self): """Get ASR timeout, reading fresh from env/DB each time to respect runtime changes.""" # Environment variables take priority env_timeout = os.environ.get('ASR_TIMEOUT') or os.environ.get('asr_timeout_seconds') if env_timeout: try: return int(env_timeout) except (ValueError, TypeError): pass # Try database setting (Admin UI) try: from src.models import SystemSetting db_timeout = SystemSetting.get_setting('asr_timeout_seconds', None) if db_timeout is not None: return int(db_timeout) except Exception: pass # Fall back to config value from initialization return self._config_timeout def _validate_config(self) -> None: """Validate required configuration.""" if not self.config.get('base_url'): raise ConfigurationError("base_url is required for ASR endpoint connector") def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse: """ Transcribe audio using ASR webservice. Args: request: Standardized transcription request Returns: TranscriptionResponse with segments and speaker information """ try: url = f"{self.base_url}/asr" params = { 'encode': True, 'task': 'transcribe', 'output': 'json' } if request.language: params['language'] = request.language logger.info(f"Using transcription language: {request.language}") # Determine if we should diarize should_diarize = request.diarize if request.diarize is not None else self.default_diarize # Send both parameter names for compatibility: # - 'diarize' is used by whisper-asr-webservice # - 'enable_diarization' is used by WhisperX params['diarize'] = should_diarize params['enable_diarization'] = should_diarize if should_diarize and self.return_embeddings: params['return_speaker_embeddings'] = True if request.min_speakers: params['min_speakers'] = request.min_speakers if request.max_speakers: params['max_speakers'] = request.max_speakers if request.prompt: params['initial_prompt'] = request.prompt if request.hotwords: params['hotwords'] = request.hotwords content_type = request.mime_type or 'application/octet-stream' files = { 'audio_file': (request.filename, request.audio_file, content_type) } # Configure timeout: generous values for large file uploads # Write timeout needs to be high too - large files take time to upload timeout = httpx.Timeout( None, connect=60.0, read=float(self.timeout), write=float(self.timeout), pool=None ) logger.info(f"Sending ASR request to {url} with params: {params} (timeout: {self.timeout}s)") with httpx.Client() as client: response = client.post(url, params=params, files=files, timeout=timeout) logger.info(f"ASR request completed with status: {response.status_code}") response.raise_for_status() # Parse the JSON response response_text = response.text try: data = response.json() except Exception as json_err: if response_text.strip().startswith('<'): logger.error(f"ASR returned HTML error page (status {response.status_code})") raise ProviderError( f"ASR service returned HTML error page", provider=self.PROVIDER_NAME, status_code=response.status_code ) else: raise ProviderError( f"ASR service returned invalid response: {json_err}", provider=self.PROVIDER_NAME, status_code=response.status_code ) return self._parse_response(data) except httpx.HTTPStatusError as e: logger.error(f"ASR request failed with status {e.response.status_code}") raise ProviderError( f"ASR request failed with status {e.response.status_code}", provider=self.PROVIDER_NAME, status_code=e.response.status_code ) from e except httpx.TimeoutException as e: logger.error(f"ASR request timed out after {self.timeout}s") raise TranscriptionError(f"ASR request timed out after {self.timeout}s") from e except Exception as e: error_msg = str(e) logger.error(f"ASR transcription failed: {error_msg}") raise TranscriptionError(f"ASR transcription failed: {error_msg}") from e def _parse_response(self, data: Dict[str, Any]) -> TranscriptionResponse: """ Parse ASR webservice response into standardized format. The ASR response contains: - text: Full transcription text - language: Detected language - segments: Array of segments with speaker, text, start, end - speaker_embeddings: Optional speaker embeddings (WhisperX only) """ segments = [] speakers = set() full_text_parts = [] last_speaker = None logger.info(f"ASR response keys: {list(data.keys())}") if 'segments' in data and isinstance(data['segments'], list): logger.info(f"Number of segments: {len(data['segments'])}") for seg in data['segments']: speaker = seg.get('speaker') # Handle missing speakers by carrying forward from previous segment if speaker is None: if last_speaker is not None: speaker = last_speaker else: speaker = 'UNKNOWN_SPEAKER' else: last_speaker = speaker text = seg.get('text', '').strip() speakers.add(speaker) full_text_parts.append(f"[{speaker}]: {text}") segments.append(TranscriptionSegment( text=text, speaker=speaker, start_time=seg.get('start'), end_time=seg.get('end') )) # Get the full text if 'text' in data and isinstance(data['text'], str): full_text = data['text'] elif full_text_parts: full_text = '\n'.join(full_text_parts) else: full_text = '' # Extract speaker embeddings if present speaker_embeddings = data.get('speaker_embeddings') if speaker_embeddings: logger.info(f"Received speaker embeddings for speakers: {list(speaker_embeddings.keys())}") logger.info(f"Parsed {len(segments)} segments with {len(speakers)} unique speakers: {sorted(speakers)}") return TranscriptionResponse( text=full_text, segments=segments, speakers=sorted(list(speakers)), speaker_embeddings=speaker_embeddings, language=data.get('language'), provider=self.PROVIDER_NAME, model="asr-endpoint", raw_response=data ) def health_check(self) -> bool: """Check if ASR endpoint is reachable.""" try: with httpx.Client(timeout=10.0) as client: # Try common health check endpoints for endpoint in ['/health', '/']: try: response = client.get(f"{self.base_url}{endpoint}") if response.status_code < 500: return True except Exception: continue return False except Exception: return False @classmethod def get_config_schema(cls) -> Dict[str, Any]: """Return JSON schema for configuration.""" return { "type": "object", "required": ["base_url"], "properties": { "base_url": { "type": "string", "description": "ASR service base URL (e.g., http://whisper-asr:9000)" }, "timeout": { "type": "integer", "default": 1800, "description": "Request timeout in seconds" }, "diarize": { "type": "boolean", "default": True, "description": "Enable speaker diarization by default" }, "return_speaker_embeddings": { "type": "boolean", "default": False, "description": "Request speaker embeddings (WhisperX only)" } } }