Initial release: DictIA v0.8.14-alpha (fork of Speakr, AGPL-3.0)

InnovA AI
2026-03-16 21:47:37 +00:00
commit 42772a31ed
365 changed files with 103572 additions and 0 deletions

src/services/__init__.py Normal file

@@ -0,0 +1,34 @@
"""
Service layer for business logic.
"""
from .embeddings import *
from .llm import *
from .document import *
from .retention import *
__all__ = [
# Embedding services
'get_embedding_model',
'chunk_transcription',
'generate_embeddings',
'serialize_embedding',
'deserialize_embedding',
'get_accessible_recording_ids',
'process_recording_chunks',
'basic_text_search_chunks',
'semantic_search_chunks',
# LLM services
'is_gpt5_model',
'is_using_openai_api',
'call_llm_completion',
'call_chat_completion',
'chat_client',
'format_api_error_message',
# Document services
'process_markdown_to_docx',
# Retention services
'is_recording_exempt_from_deletion',
'get_retention_days_for_recording',
'process_auto_deletion',
]

src/services/audit.py Normal file

@@ -0,0 +1,211 @@
"""
Central audit service for Loi 25 compliance.
Wraps AccessLog and AuthLog with request context helpers.
All logging is gated behind ENABLE_AUDIT_LOG env var.
NOTE: ENABLE_AUDIT_LOG is read once at import time. Changing the env var
requires an application restart to take effect.
NOTE: Audit helpers do NOT commit the session — the caller's transaction
will persist the log entry when it commits. This avoids interfering with
ongoing transactions (e.g. a delete operation that audits then deletes).
"""
import os
import logging
from datetime import datetime, timedelta
from flask import request, has_request_context
from flask_login import current_user
from src.database import db
from src.models.access_log import AccessLog
from src.models.auth_log import AuthLog
logger = logging.getLogger(__name__)
# Read once at import — requires restart to change.
ENABLE_AUDIT_LOG = os.environ.get('ENABLE_AUDIT_LOG', 'false').lower() == 'true'
def is_audit_enabled():
"""Check if audit logging is enabled."""
return ENABLE_AUDIT_LOG
def _get_request_context():
"""Extract IP address and user agent from current request."""
if not has_request_context():
return None, None
ip = request.remote_addr
ua = request.headers.get('User-Agent', '')[:500]
return ip, ua
def _get_current_user_id():
"""Get current user ID if authenticated."""
if has_request_context() and current_user and current_user.is_authenticated:
return current_user.id
return None
# --- Access logging helpers ---
_DEDUP_WINDOW_SECONDS = 300 # 5 minutes
def _is_recent_duplicate(action, resource_type, resource_id, user_id):
"""Return True if the same user already logged this action on this resource in the last 5 min.
Prevents unbounded log growth for high-frequency read operations (e.g. view on every GET).
"""
if user_id is None or resource_id is None:
return False
cutoff = datetime.utcnow() - timedelta(seconds=_DEDUP_WINDOW_SECONDS)
return AccessLog.query.filter(
AccessLog.user_id == user_id,
AccessLog.action == action,
AccessLog.resource_type == resource_type,
AccessLog.resource_id == resource_id,
AccessLog.timestamp >= cutoff,
).first() is not None
def audit_access(action, resource_type, resource_id=None, user_id=None, status='success', details=None):
"""Log a data access event if audit is enabled.
Does NOT commit — the log is added to the current session and will be
persisted when the caller (or Flask teardown) commits.
View events are deduplicated: the same user viewing the same resource within
5 minutes is logged only once to avoid unbounded log growth.
"""
if not ENABLE_AUDIT_LOG:
return None
try:
ip, ua = _get_request_context()
if user_id is None:
user_id = _get_current_user_id()
if action == 'view' and _is_recent_duplicate(action, resource_type, resource_id, user_id):
return None
log = AccessLog.log_access(
action=action,
resource_type=resource_type,
resource_id=resource_id,
user_id=user_id,
status=status,
details=details,
ip_address=ip,
user_agent=ua,
)
return log
except Exception as e:
logger.warning(f"Failed to write access audit log: {e}")
return None
def audit_view(resource_type, resource_id, **kwargs):
"""Log a view access."""
return audit_access('view', resource_type, resource_id, **kwargs)
def audit_download(resource_type, resource_id, **kwargs):
"""Log a download access."""
return audit_access('download', resource_type, resource_id, **kwargs)
def audit_edit(resource_type, resource_id, **kwargs):
"""Log an edit."""
return audit_access('edit', resource_type, resource_id, **kwargs)
def audit_delete(resource_type, resource_id, **kwargs):
"""Log a deletion."""
return audit_access('delete', resource_type, resource_id, **kwargs)
def audit_export(resource_type, resource_id, **kwargs):
"""Log a data export."""
return audit_access('export', resource_type, resource_id, **kwargs)
# --- Auth logging helpers ---
# Auth events (login/logout/failed) are standalone operations, so they
# commit their own transaction since the caller typically redirects after.
def _audit_auth(func, *args, **kwargs):
"""Wrapper for auth audit helpers that commits independently."""
if not ENABLE_AUDIT_LOG:
return None
try:
ip, ua = _get_request_context()
log = func(*args, ip_address=ip, user_agent=ua, **kwargs)
db.session.commit()
return log
except Exception as e:
logger.warning(f"Failed to write auth audit log: {e}")
db.session.rollback()
return None
def audit_login(user_id, details=None):
"""Log a successful login."""
return _audit_auth(AuthLog.log_login, user_id, details=details)
def audit_logout(user_id=None):
"""Log a logout."""
if user_id is None:
user_id = _get_current_user_id()
return _audit_auth(AuthLog.log_logout, user_id)
def audit_failed_login(details=None):
"""Log a failed login."""
return _audit_auth(AuthLog.log_failed_login, details=details)
def audit_register(user_id):
"""Log a registration."""
return _audit_auth(AuthLog.log_register, user_id)
def audit_password_change(user_id, details=None):
"""Log a password change."""
return _audit_auth(AuthLog.log_password_change, user_id, details=details)
def audit_password_reset(user_id):
"""Log a password reset."""
return _audit_auth(AuthLog.log_password_reset, user_id)
def audit_sso_login(user_id, details=None):
"""Log an SSO login."""
return _audit_auth(AuthLog.log_sso_login, user_id, details=details)
# --- Query helpers for admin ---
def get_access_logs(page=1, per_page=50, user_id=None, resource_type=None, resource_id=None, action=None):
"""Query access logs with pagination and filters."""
query = AccessLog.query.order_by(AccessLog.timestamp.desc())
if user_id is not None:
query = query.filter_by(user_id=user_id)
if resource_type is not None:
query = query.filter_by(resource_type=resource_type)
if resource_id is not None:
query = query.filter_by(resource_id=resource_id)
if action is not None:
query = query.filter_by(action=action)
return query.paginate(page=page, per_page=per_page, error_out=False)
def get_auth_logs(page=1, per_page=50, user_id=None, action=None):
"""Query auth logs with pagination and filters."""
query = AuthLog.query.order_by(AuthLog.timestamp.desc())
if user_id is not None:
query = query.filter_by(user_id=user_id)
if action is not None:
query = query.filter_by(action=action)
return query.paginate(page=page, per_page=per_page, error_out=False)
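# Usage sketch (illustrative, not part of this module's API): a Flask route that
# serves a recording could audit the view inside its own transaction, so the
# AccessLog row is committed together with any other pending changes. Route and
# model names below are hypothetical.
#
#     @app.route('/recordings/<int:rec_id>')
#     @login_required
#     def view_recording(rec_id):
#         recording = Recording.query.get_or_404(rec_id)
#         audit_view('recording', rec_id)  # adds to the session, no commit
#         db.session.commit()              # persists the audit entry too
#         return render_template('recording.html', recording=recording)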

src/services/calendar.py Normal file

@@ -0,0 +1,102 @@
"""
Calendar/ICS file generation services.
"""
import json
import uuid
from datetime import datetime, timedelta
def generate_ics_content(event):
"""Generate ICS calendar file content for an event."""
# Generate unique ID for the event
uid = f"{event.id}-{uuid.uuid4()}@speakr.app"
# Format dates in iCalendar format (YYYYMMDDTHHMMSS)
def format_ical_date(dt):
if dt:
return dt.strftime('%Y%m%dT%H%M%S')
return None
# Start building ICS content
lines = [
'BEGIN:VCALENDAR',
'VERSION:2.0',
'PRODID:-//Speakr//Event Export//EN',
'CALSCALE:GREGORIAN',
'METHOD:PUBLISH',
'BEGIN:VEVENT',
f'UID:{uid}',
        f'DTSTAMP:{datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")}',  # DTSTAMP must be in UTC (trailing Z) per RFC 5545
]
# Add event details
if event.start_datetime:
lines.append(f'DTSTART:{format_ical_date(event.start_datetime)}')
if event.end_datetime:
lines.append(f'DTEND:{format_ical_date(event.end_datetime)}')
elif event.start_datetime:
# If no end time, default to 1 hour after start
end_time = event.start_datetime + timedelta(hours=1)
lines.append(f'DTEND:{format_ical_date(end_time)}')
# Add title and description
lines.append(f'SUMMARY:{escape_ical_text(event.title)}')
if event.description:
lines.append(f'DESCRIPTION:{escape_ical_text(event.description)}')
# Add location if available
if event.location:
lines.append(f'LOCATION:{escape_ical_text(event.location)}')
# Add attendees if available
if event.attendees:
try:
attendees_list = json.loads(event.attendees)
for attendee in attendees_list:
if attendee:
                    lines.append(f'ATTENDEE;CN={escape_ical_text(attendee)}:mailto:{attendee.replace(" ", ".").lower()}@example.com')
        except (json.JSONDecodeError, TypeError):
            pass  # attendees field is not valid JSON; omit attendees rather than fail
    # Event status and transparency (these properties must precede the VALARM component per RFC 5545)
    lines.extend([
        'STATUS:CONFIRMED',
        'TRANSP:OPAQUE',
    ])
    # Add reminder/alarm if specified
if event.reminder_minutes and event.reminder_minutes > 0:
lines.extend([
'BEGIN:VALARM',
            f'TRIGGER:-PT{event.reminder_minutes}M',
'ACTION:DISPLAY',
f'DESCRIPTION:Reminder: {escape_ical_text(event.title)}',
'END:VALARM'
])
# Close event and calendar
lines.extend([
'END:VEVENT',
'END:VCALENDAR'
])
return '\r\n'.join(lines)
def escape_ical_text(text):
"""Escape special characters for iCalendar format."""
if not text:
return ''
# Escape special characters
text = str(text)
text = text.replace('\\', '\\\\')
text = text.replace(',', '\\,')
text = text.replace(';', '\\;')
text = text.replace('\n', '\\n')
return text
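# Usage sketch (illustrative): given an event-like object with the attributes
# read above (id, title, start_datetime, ...), the generated ICS can be served
# as a download. The Flask response below is hypothetical.
#
#     ics = generate_ics_content(event)
#     return Response(ics, mimetype='text/calendar',
#                     headers={'Content-Disposition': 'attachment; filename=event.ics'})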

src/services/document.py Normal file

@@ -0,0 +1,296 @@
"""
Document processing and conversion services.
"""
import re
from docx import Document
from docx.shared import Pt, RGBColor
def process_markdown_to_docx(doc, content):
"""Convert markdown content to properly formatted Word document elements.
Supports:
- Tables (markdown pipe tables)
- Headings (# ## ###)
- Bold text (**text**)
- Italic text (*text* or _text_)
- Bold italic (***text***)
- Inline code (`code`)
- Code blocks (```code```)
- Strikethrough (~~text~~)
- Links ([text](url))
- Bullet lists (- or *)
- Numbered lists (1. 2. 3.)
- Horizontal rules (--- or ***)
"""
    from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
    from docx.oxml.ns import qn
def ensure_unicode_font(run, text):
"""Ensure the run uses a font that supports the characters in the text."""
# Check if text contains non-ASCII characters
try:
text.encode('ascii')
# Text is pure ASCII, no special font needed
except UnicodeEncodeError:
# Text contains non-ASCII characters, use a font with better Unicode support
# Use Arial for broad compatibility - it has good Unicode support on most systems
run.font.name = 'Arial'
# Set the East Asian font for CJK (Chinese, Japanese, Korean) text
# This ensures proper rendering in Word
r = run._element
r.rPr.rFonts.set(qn('w:eastAsia'), 'Arial')
return run
def add_formatted_run(paragraph, text):
"""Add a run with inline formatting to a paragraph."""
if not text:
return
# Pattern for all inline formatting
# Order matters: check triple asterisk before double/single
patterns = [
(r'\*\*\*(.*?)\*\*\*', lambda p, t: (lambda r: (setattr(r, 'bold', True), setattr(r, 'italic', True), ensure_unicode_font(r, t)))(p.add_run(t))), # Bold italic
(r'\*\*(.*?)\*\*', lambda p, t: (lambda r: (setattr(r, 'bold', True), ensure_unicode_font(r, t)))(p.add_run(t))), # Bold
(r'(?<!\*)\*(?!\*)(.*?)\*(?!\*)', lambda p, t: (lambda r: (setattr(r, 'italic', True), ensure_unicode_font(r, t)))(p.add_run(t))), # Italic with *
(r'\b_(.*?)_\b', lambda p, t: (lambda r: (setattr(r, 'italic', True), ensure_unicode_font(r, t)))(p.add_run(t))), # Italic with _
(r'~~(.*?)~~', lambda p, t: (lambda r: (setattr(r, 'strike', True), ensure_unicode_font(r, t)))(p.add_run(t))), # Strikethrough
(r'`([^`]+)`', lambda p, t: add_code_run(p, t)), # Inline code
(r'\[([^\]]+)\]\(([^)]+)\)', lambda p, t, u: add_link_run(p, t, u)), # Links
]
def add_code_run(para, text):
"""Add inline code with monospace font and background."""
run = para.add_run(text)
run.font.name = 'Courier New'
run.font.size = Pt(10)
run.font.color.rgb = RGBColor(220, 20, 60) # Crimson color for code
# Check if we need Unicode support for code
try:
text.encode('ascii')
except UnicodeEncodeError:
# Use Consolas as fallback for better Unicode support in monospace
r = run._element
r.rPr.rFonts.set(qn('w:eastAsia'), 'Consolas')
return run
def add_link_run(para, text, url):
"""Add a hyperlink-styled run (note: actual hyperlinks require more complex handling)."""
full_text = f"{text} ({url})"
run = para.add_run(full_text)
run.font.color.rgb = RGBColor(0, 0, 255) # Blue color for links
run.font.underline = True
ensure_unicode_font(run, full_text)
return run
# Process the text with all patterns
remaining_text = text
while remaining_text:
earliest_match = None
earliest_pos = len(remaining_text)
matched_pattern = None
# Find the earliest matching pattern
for pattern, handler in patterns:
match = re.search(pattern, remaining_text)
if match and match.start() < earliest_pos:
earliest_match = match
earliest_pos = match.start()
matched_pattern = handler
if earliest_match:
# Add text before the match
if earliest_pos > 0:
run = paragraph.add_run(remaining_text[:earliest_pos])
ensure_unicode_font(run, remaining_text[:earliest_pos])
# Apply formatting for the matched text
if '[' in earliest_match.group(0) and '](' in earliest_match.group(0):
# Special handling for links (two groups)
matched_pattern(paragraph, earliest_match.group(1), earliest_match.group(2))
else:
matched_pattern(paragraph, earliest_match.group(1))
# Continue with remaining text
remaining_text = remaining_text[earliest_match.end():]
else:
# No more patterns, add the rest as plain text
run = paragraph.add_run(remaining_text)
ensure_unicode_font(run, remaining_text)
break
def parse_table(lines, start_idx):
"""Parse a markdown table starting at the given index."""
if start_idx >= len(lines):
return None, start_idx
# Check if this looks like a table
if '|' not in lines[start_idx]:
return None, start_idx
table_data = []
idx = start_idx
while idx < len(lines) and '|' in lines[idx]:
# Skip separator lines
if re.match(r'^[\s\|\-:]+$', lines[idx]):
idx += 1
continue
# Parse cells
cells = [cell.strip() for cell in lines[idx].split('|')]
# Remove empty cells at start and end
if cells and not cells[0]:
cells = cells[1:]
if cells and not cells[-1]:
cells = cells[:-1]
if cells:
table_data.append(cells)
idx += 1
if table_data:
return table_data, idx
return None, start_idx
# Split content into lines
lines = content.split('\n')
i = 0
in_code_block = False
code_block_content = []
while i < len(lines):
line = lines[i]
# Handle code blocks
if line.strip().startswith('```'):
if not in_code_block:
in_code_block = True
code_block_content = []
else:
# End of code block - add it as preformatted text
in_code_block = False
if code_block_content:
p = doc.add_paragraph()
p.style = 'Normal'
code_text = '\n'.join(code_block_content)
run = p.add_run(code_text)
run.font.name = 'Courier New'
run.font.size = Pt(10)
run.font.color.rgb = RGBColor(64, 64, 64)
# Check if we need Unicode support for code blocks
try:
code_text.encode('ascii')
except UnicodeEncodeError:
r = run._element
r.rPr.rFonts.set(qn('w:eastAsia'), 'Consolas')
i += 1
continue
if in_code_block:
code_block_content.append(line)
i += 1
continue
# Check for table
table_data, end_idx = parse_table(lines, i)
if table_data:
# Create Word table
table = doc.add_table(rows=len(table_data), cols=len(table_data[0]))
table.style = 'Table Grid'
# Populate table
for row_idx, row_data in enumerate(table_data):
for col_idx, cell_text in enumerate(row_data):
if col_idx < len(table.rows[row_idx].cells):
cell = table.rows[row_idx].cells[col_idx]
                        # Clear the cell and reuse its default paragraph; add_paragraph() would leave an extra blank line
                        cell.text = ""
                        p = cell.paragraphs[0]
add_formatted_run(p, cell_text)
# Make header row bold
if row_idx == 0:
for run in p.runs:
run.bold = True
doc.add_paragraph('') # Space after table
i = end_idx
continue
line = line.rstrip()
# Skip empty lines
if not line:
doc.add_paragraph('')
i += 1
continue
# Horizontal rule
if re.match(r'^(\*{3,}|-{3,}|_{3,})$', line.strip()):
            p = doc.add_paragraph('_' * 50)  # render the rule as a visible line of underscores
p.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
i += 1
continue
# Headings
if line.startswith('# '):
doc.add_heading(line[2:], 1)
elif line.startswith('## '):
doc.add_heading(line[3:], 2)
elif line.startswith('### '):
doc.add_heading(line[4:], 3)
elif line.startswith('#### '):
doc.add_heading(line[5:], 4)
# Bullet points
elif line.lstrip().startswith('- ') or line.lstrip().startswith('* '):
# Get the indentation level
indent = len(line) - len(line.lstrip())
bullet_text = line.lstrip()[2:]
p = doc.add_paragraph(style='List Bullet')
# Add indentation if nested
if indent > 0:
p.paragraph_format.left_indent = Pt(indent * 10)
add_formatted_run(p, bullet_text)
# Numbered lists
elif re.match(r'^\s*\d+\.', line):
match = re.match(r'^(\s*)(\d+)\.\s*(.*)', line)
if match:
indent = len(match.group(1))
list_text = match.group(3)
p = doc.add_paragraph(style='List Number')
if indent > 0:
p.paragraph_format.left_indent = Pt(indent * 10)
add_formatted_run(p, list_text)
# Blockquote
elif line.startswith('> '):
p = doc.add_paragraph()
p.paragraph_format.left_indent = Pt(30)
add_formatted_run(p, line[2:])
# Add a gray color to indicate quote
for run in p.runs:
run.font.color.rgb = RGBColor(100, 100, 100)
else:
# Regular paragraph
p = doc.add_paragraph()
add_formatted_run(p, line)
i += 1
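# Usage sketch (illustrative): render a markdown summary into a Word file.
# `Document` is imported at the top of this module; the output path is hypothetical.
#
#     doc = Document()
#     process_markdown_to_docx(doc, "# Minutes\n\nSome **bold** text\n\n- item one\n- item two")
#     doc.save('/tmp/minutes.docx')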

src/services/email.py Normal file

@@ -0,0 +1,443 @@
"""
Email service for verification and password reset.
This module provides email functionality using Python's built-in smtplib.
All email features are opt-in via environment variables.
"""
import os
import smtplib
import logging
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from typing import Optional
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
from flask import current_app, url_for
logger = logging.getLogger(__name__)
# Token expiry times
EMAIL_VERIFICATION_EXPIRY = 24 * 60 * 60 # 24 hours in seconds
PASSWORD_RESET_EXPIRY = 1 * 60 * 60 # 1 hour in seconds
def get_email_config():
"""Get email configuration from environment variables."""
return {
'enabled': os.environ.get('ENABLE_EMAIL_VERIFICATION', 'false').lower() == 'true',
'required': os.environ.get('REQUIRE_EMAIL_VERIFICATION', 'false').lower() == 'true',
'smtp_host': os.environ.get('SMTP_HOST', ''),
'smtp_port': int(os.environ.get('SMTP_PORT', '587')),
'smtp_username': os.environ.get('SMTP_USERNAME', ''),
'smtp_password': os.environ.get('SMTP_PASSWORD', ''),
'smtp_use_tls': os.environ.get('SMTP_USE_TLS', 'true').lower() == 'true',
'smtp_use_ssl': os.environ.get('SMTP_USE_SSL', 'false').lower() == 'true',
'from_address': os.environ.get('SMTP_FROM_ADDRESS', 'noreply@yourdomain.com'),
'from_name': os.environ.get('SMTP_FROM_NAME', 'Speakr'),
}
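# Example environment (illustrative values; these are the variables read above):
#   ENABLE_EMAIL_VERIFICATION=true
#   REQUIRE_EMAIL_VERIFICATION=false
#   SMTP_HOST=smtp.example.com
#   SMTP_PORT=587
#   SMTP_USERNAME=mailer
#   SMTP_PASSWORD=change-me
#   SMTP_USE_TLS=true
#   SMTP_USE_SSL=false
#   SMTP_FROM_ADDRESS=noreply@example.com
#   SMTP_FROM_NAME=Speakr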
def is_email_verification_enabled() -> bool:
"""Check if email verification is enabled."""
return get_email_config()['enabled']
def is_email_verification_required() -> bool:
"""Check if email verification is required for login."""
config = get_email_config()
return config['enabled'] and config['required']
def is_smtp_configured() -> bool:
"""Check if SMTP settings are properly configured."""
config = get_email_config()
return bool(config['smtp_host'] and config['smtp_username'] and config['smtp_password'])
def get_serializer(salt: str) -> URLSafeTimedSerializer:
"""Get a URL-safe timed serializer for token generation."""
secret_key = current_app.config.get('SECRET_KEY', 'default-dev-key')
return URLSafeTimedSerializer(secret_key, salt=salt)
def generate_verification_token(user_id: int) -> str:
"""Generate an email verification token."""
serializer = get_serializer('email-verification')
return serializer.dumps(user_id)
def generate_password_reset_token(user_id: int) -> str:
"""Generate a password reset token."""
serializer = get_serializer('password-reset')
return serializer.dumps(user_id)
def verify_email_token(token: str) -> Optional[int]:
"""
Verify an email verification token.
Returns the user_id if valid, None otherwise.
"""
serializer = get_serializer('email-verification')
try:
user_id = serializer.loads(token, max_age=EMAIL_VERIFICATION_EXPIRY)
return user_id
except SignatureExpired:
logger.warning("Email verification token expired")
return None
except BadSignature:
logger.warning("Invalid email verification token")
return None
def verify_reset_token(token: str) -> Optional[int]:
"""
Verify a password reset token.
Returns the user_id if valid, None otherwise.
"""
serializer = get_serializer('password-reset')
try:
user_id = serializer.loads(token, max_age=PASSWORD_RESET_EXPIRY)
return user_id
except SignatureExpired:
logger.warning("Password reset token expired")
return None
except BadSignature:
logger.warning("Invalid password reset token")
return None
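# Token round trip (illustrative): tokens are opaque signed values, so the only
# supported check is verifying them back to a user id. The two token types use
# different salts, so one cannot be replayed as the other.
#
#     token = generate_verification_token(user.id)
#     assert verify_email_token(token) == user.id   # valid for 24 hours
#     assert verify_reset_token(token) is None      # different salt, rejected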
def _send_email(to_email: str, subject: str, html_body: str, text_body: Optional[str] = None) -> bool:
"""
Send an email using SMTP.
Returns True if successful, False otherwise.
"""
config = get_email_config()
if not is_smtp_configured():
logger.error("SMTP is not configured. Cannot send email.")
return False
try:
msg = MIMEMultipart('alternative')
msg['Subject'] = subject
msg['From'] = f"{config['from_name']} <{config['from_address']}>"
msg['To'] = to_email
# Add plain text version
if text_body:
part1 = MIMEText(text_body, 'plain')
msg.attach(part1)
# Add HTML version
part2 = MIMEText(html_body, 'html')
msg.attach(part2)
# Connect to SMTP server
if config['smtp_use_ssl']:
server = smtplib.SMTP_SSL(config['smtp_host'], config['smtp_port'])
else:
server = smtplib.SMTP(config['smtp_host'], config['smtp_port'])
if config['smtp_use_tls']:
server.starttls()
server.login(config['smtp_username'], config['smtp_password'])
server.sendmail(config['from_address'], to_email, msg.as_string())
server.quit()
logger.info(f"Email sent successfully to {to_email}")
return True
except smtplib.SMTPAuthenticationError as e:
logger.error(f"SMTP authentication failed: {e}")
return False
except smtplib.SMTPException as e:
logger.error(f"SMTP error sending email: {e}")
return False
except Exception as e:
logger.error(f"Error sending email: {e}")
return False
def _get_email_template(content_html: str, content_text: str, subject: str) -> tuple[str, str]:
"""
Wrap content in the Speakr email template.
Returns (html_body, text_body)
"""
# Get the base URL for the logo
try:
logo_url = url_for('static', filename='img/icon-192x192.png', _external=True)
except RuntimeError:
# Outside of request context, use a placeholder
logo_url = ""
html_body = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; line-height: 1.6; color: #1f2937; margin: 0; padding: 0; background-color: #e8eaed;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="100%" style="background-color: #e8eaed;">
<tr>
<td style="padding: 40px 20px;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="600" style="max-width: 600px; margin: 0 auto;">
<!-- Header -->
<tr>
<td style="background-color: #2563eb; padding: 32px 40px; border-radius: 12px 12px 0 0;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="100%">
<tr>
<td>
<!-- Logo and Brand -->
<table role="presentation" cellspacing="0" cellpadding="0" border="0">
<tr>
<td style="vertical-align: middle; padding-right: 12px;">
<img src="{logo_url}" alt="Speakr" width="44" height="44" style="display: block; border-radius: 8px;">
</td>
<td style="vertical-align: middle;">
<h1 style="color: #ffffff; margin: 0; font-size: 28px; font-weight: 700; letter-spacing: -0.5px;">Speakr</h1>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td style="padding-top: 8px;">
<p style="color: rgba(255,255,255,0.85); margin: 0; font-size: 14px;">AI-Powered Audio Transcription</p>
</td>
</tr>
</table>
</td>
</tr>
<!-- Content -->
<tr>
<td style="background-color: #ffffff; padding: 40px; border-left: 1px solid #e5e7eb; border-right: 1px solid #e5e7eb;">
{content_html}
</td>
</tr>
<!-- Footer -->
<tr>
<td style="background-color: #f8f9fa; padding: 24px 40px; border-radius: 0 0 12px 12px; border: 1px solid #e5e7eb; border-top: none;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="100%">
<tr>
<td style="text-align: center;">
<p style="color: #6b7280; font-size: 12px; margin: 0 0 8px 0;">
This email was sent by Speakr. If you have questions, please contact your administrator.
</p>
<p style="color: #9ca3af; font-size: 11px; margin: 0;">
&copy; {datetime.utcnow().year} Speakr &middot; AI-Powered Audio Transcription
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>
"""
text_body = f"""
{subject}
{'=' * len(subject)}
{content_text}
---
This email was sent by Speakr - AI-Powered Audio Transcription.
If you have questions, please contact your administrator.
"""
return html_body, text_body
def send_verification_email(user) -> bool:
"""
Send a verification email to a user.
Args:
user: User model instance
Returns True if email was sent successfully, False otherwise.
"""
from src.database import db
if not is_email_verification_enabled():
logger.debug("Email verification is disabled")
return False
if not is_smtp_configured():
logger.warning("Cannot send verification email: SMTP not configured")
return False
# Generate token and store it
token = generate_verification_token(user.id)
user.email_verification_token = token
user.email_verification_sent_at = datetime.utcnow()
db.session.commit()
# Build verification URL
verify_url = url_for('auth.verify_email', token=token, _external=True)
subject = "Verify your email address - Speakr"
content_html = f"""
<h2 style="color: #1f2937; margin: 0 0 24px 0; font-size: 24px; font-weight: 600;">Verify Your Email Address</h2>
<p style="color: #374151; margin: 0 0 16px 0; font-size: 16px;">Hi {user.username},</p>
<p style="color: #374151; margin: 0 0 24px 0; font-size: 16px;">
Welcome to Speakr! To complete your registration and start transcribing your audio recordings, please verify your email address.
</p>
<div style="text-align: center; margin: 32px 0;">
<a href="{verify_url}" style="display: inline-block; background-color: #2563eb; color: #ffffff; text-decoration: none; padding: 14px 32px; border-radius: 8px; font-weight: 600; font-size: 16px;">Verify Email Address</a>
</div>
<p style="color: #6b7280; font-size: 14px; margin: 24px 0 8px 0;">Or copy and paste this link into your browser:</p>
<p style="word-break: break-all; color: #2563eb; font-size: 14px; margin: 0; padding: 12px; background-color: #f3f4f6; border-radius: 6px;">{verify_url}</p>
<div style="margin-top: 32px; padding-top: 24px; border-top: 1px solid #e5e7eb;">
<p style="color: #9ca3af; font-size: 13px; margin: 0;">
<strong>This link will expire in 24 hours.</strong><br>
If you didn't create an account on Speakr, you can safely ignore this email.
</p>
</div>
"""
content_text = f"""Hi {user.username},
Welcome to Speakr! To complete your registration and start transcribing your audio recordings, please verify your email address.
Click here to verify: {verify_url}
This link will expire in 24 hours.
If you didn't create an account on Speakr, you can safely ignore this email."""
html_body, text_body = _get_email_template(content_html, content_text, subject)
return _send_email(user.email, subject, html_body, text_body)
def send_password_reset_email(user) -> bool:
"""
Send a password reset email to a user.
Args:
user: User model instance
Returns True if email was sent successfully, False otherwise.
"""
from src.database import db
if not is_smtp_configured():
logger.warning("Cannot send password reset email: SMTP not configured")
return False
# Generate token and store it
token = generate_password_reset_token(user.id)
user.password_reset_token = token
user.password_reset_sent_at = datetime.utcnow()
db.session.commit()
# Build reset URL
reset_url = url_for('auth.reset_password', token=token, _external=True)
subject = "Reset your password - Speakr"
content_html = f"""
<h2 style="color: #1f2937; margin: 0 0 24px 0; font-size: 24px; font-weight: 600;">Reset Your Password</h2>
<p style="color: #374151; margin: 0 0 16px 0; font-size: 16px;">Hi {user.username},</p>
<p style="color: #374151; margin: 0 0 24px 0; font-size: 16px;">
We received a request to reset your Speakr account password. Click the button below to create a new password.
</p>
<div style="text-align: center; margin: 32px 0;">
<a href="{reset_url}" style="display: inline-block; background-color: #2563eb; color: #ffffff; text-decoration: none; padding: 14px 32px; border-radius: 8px; font-weight: 600; font-size: 16px;">Reset Password</a>
</div>
<p style="color: #6b7280; font-size: 14px; margin: 24px 0 8px 0;">Or copy and paste this link into your browser:</p>
<p style="word-break: break-all; color: #2563eb; font-size: 14px; margin: 0; padding: 12px; background-color: #f3f4f6; border-radius: 6px;">{reset_url}</p>
<div style="margin-top: 32px; padding-top: 24px; border-top: 1px solid #e5e7eb;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="100%">
<tr>
<td style="width: 24px; vertical-align: top; padding-right: 12px;">
<span style="font-size: 18px;">⚠️</span>
</td>
<td>
<p style="color: #9ca3af; font-size: 13px; margin: 0;">
<strong style="color: #6b7280;">This link will expire in 1 hour.</strong><br>
If you didn't request a password reset, you can safely ignore this email. Your password will remain unchanged.
</p>
</td>
</tr>
</table>
</div>
"""
content_text = f"""Hi {user.username},
We received a request to reset your Speakr account password. Click the link below to create a new password:
{reset_url}
This link will expire in 1 hour.
If you didn't request a password reset, you can safely ignore this email. Your password will remain unchanged."""
html_body, text_body = _get_email_template(content_html, content_text, subject)
return _send_email(user.email, subject, html_body, text_body)
def can_resend_verification(user) -> tuple[bool, Optional[int]]:
"""
Check if a verification email can be resent.
Returns (can_resend, seconds_until_can_resend)
"""
if not user.email_verification_sent_at:
return True, None
# Allow resend after 60 seconds
cooldown = timedelta(seconds=60)
time_since_last = datetime.utcnow() - user.email_verification_sent_at
if time_since_last >= cooldown:
return True, None
remaining = (cooldown - time_since_last).seconds
return False, remaining
def can_resend_password_reset(user) -> tuple[bool, Optional[int]]:
"""
Check if a password reset email can be resent.
Returns (can_resend, seconds_until_can_resend)
"""
if not user.password_reset_sent_at:
return True, None
# Allow resend after 60 seconds
cooldown = timedelta(seconds=60)
time_since_last = datetime.utcnow() - user.password_reset_sent_at
if time_since_last >= cooldown:
return True, None
remaining = (cooldown - time_since_last).seconds
return False, remaining
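# Usage sketch (illustrative): gate a resend endpoint on the cooldown helpers.
#
#     allowed, wait_s = can_resend_verification(user)
#     if allowed:
#         send_verification_email(user)
#     else:
#         flash(f"Please wait {wait_s}s before requesting another email.")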

src/services/embeddings.py Normal file

@@ -0,0 +1,422 @@
"""
Embedding generation and semantic search services.
"""
import os
import numpy as np
from flask import current_app
from sqlalchemy.orm import joinedload
try:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
EMBEDDINGS_AVAILABLE = True
except ImportError:
EMBEDDINGS_AVAILABLE = False
cosine_similarity = None
from src.database import db
from src.models import Recording, TranscriptChunk, InternalShare, RecordingTag
ENABLE_INTERNAL_SHARING = os.environ.get('ENABLE_INTERNAL_SHARING', 'false').lower() == 'true'
# Initialize embedding model (lazy loading)
_embedding_model = None
def get_embedding_model():
"""Get or initialize the sentence transformer model."""
global _embedding_model
if not EMBEDDINGS_AVAILABLE:
return None
if _embedding_model is None:
try:
_embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
current_app.logger.info("Embedding model loaded successfully")
except Exception as e:
current_app.logger.error(f"Failed to load embedding model: {e}")
return None
return _embedding_model
def chunk_transcription(transcription, max_chunk_length=500, overlap=50):
"""
Split transcription into overlapping chunks for better context retrieval.
Args:
transcription (str): The full transcription text
max_chunk_length (int): Maximum characters per chunk
overlap (int): Character overlap between chunks
Returns:
list: List of text chunks
"""
if not transcription or len(transcription) <= max_chunk_length:
return [transcription] if transcription else []
chunks = []
start = 0
while start < len(transcription):
end = start + max_chunk_length
# Try to break at sentence boundaries
if end < len(transcription):
# Look for sentence endings within the last 100 characters
sentence_end = -1
for i in range(max(0, end - 100), end):
if transcription[i] in '.!?':
# Check if it's not an abbreviation
if i + 1 < len(transcription) and transcription[i + 1].isspace():
sentence_end = i + 1
if sentence_end > start:
end = sentence_end
chunk = transcription[start:end].strip()
if chunk:
chunks.append(chunk)
# Move start position with overlap
start = max(start + 1, end - overlap)
# Prevent infinite loop
if start >= len(transcription):
break
return chunks
def generate_embeddings(texts):
"""
Generate embeddings for a list of texts.
Args:
texts (list): List of text strings
Returns:
list: List of embedding vectors as numpy arrays, or empty list if embeddings unavailable
"""
if not EMBEDDINGS_AVAILABLE:
current_app.logger.warning("Embeddings not available - skipping embedding generation")
return []
model = get_embedding_model()
if not model or not texts:
return []
try:
embeddings = model.encode(texts)
return [embedding.astype(np.float32) for embedding in embeddings]
except Exception as e:
current_app.logger.error(f"Error generating embeddings: {e}")
return []
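# Usage sketch (illustrative): chunk a transcription, embed the chunks, and
# serialize the vectors for storage, as process_recording_chunks does below.
#
#     chunks = chunk_transcription(recording.transcription)
#     vectors = generate_embeddings(chunks)   # [] when the model is unavailable
#     blobs = [serialize_embedding(v) for v in vectors]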
def serialize_embedding(embedding):
"""Convert numpy array to binary for database storage."""
if embedding is None or not EMBEDDINGS_AVAILABLE:
return None
return embedding.tobytes()
def deserialize_embedding(binary_data):
"""Convert binary data back to numpy array."""
if binary_data is None or not EMBEDDINGS_AVAILABLE:
return None
return np.frombuffer(binary_data, dtype=np.float32)
def get_accessible_recording_ids(user_id):
"""
Get all recording IDs that a user has access to.
Includes:
- Recordings owned by the user
- Recordings shared with the user via InternalShare
- Recordings shared via group tags (if team membership exists)
Args:
user_id (int): User ID to check access for
Returns:
list: List of recording IDs the user can access
"""
accessible_ids = set()
# 1. User's own recordings
own_recordings = db.session.query(Recording.id).filter_by(user_id=user_id).all()
accessible_ids.update([r.id for r in own_recordings])
# 2. Internally shared recordings
if ENABLE_INTERNAL_SHARING:
shared_recordings = db.session.query(InternalShare.recording_id).filter_by(
shared_with_user_id=user_id
).all()
accessible_ids.update([r.recording_id for r in shared_recordings])
return list(accessible_ids)
def process_recording_chunks(recording_id):
"""
Process a recording by creating chunks and generating embeddings.
This should be called after a recording is transcribed.
"""
try:
recording = db.session.get(Recording, recording_id)
if not recording or not recording.transcription:
return False
# Delete existing chunks for this recording
TranscriptChunk.query.filter_by(recording_id=recording_id).delete()
# Create chunks
chunks = chunk_transcription(recording.transcription)
if not chunks:
return True
# Generate embeddings
embeddings = generate_embeddings(chunks)
# Store chunks in database
for i, (chunk_text, embedding) in enumerate(zip(chunks, embeddings)):
chunk = TranscriptChunk(
recording_id=recording_id,
user_id=recording.user_id,
chunk_index=i,
content=chunk_text,
embedding=serialize_embedding(embedding) if embedding is not None else None
)
db.session.add(chunk)
db.session.commit()
current_app.logger.info(f"Created {len(chunks)} chunks for recording {recording_id}")
return True
except Exception as e:
current_app.logger.error(f"Error processing chunks for recording {recording_id}: {e}")
db.session.rollback()
return False
def basic_text_search_chunks(user_id, query, filters=None, top_k=5):
"""
Basic text search fallback when embeddings are not available.
Uses simple text matching instead of semantic search.
Searches across user's own recordings and recordings shared with them.
"""
try:
# Get all accessible recording IDs (own + shared)
accessible_recording_ids = get_accessible_recording_ids(user_id)
if not accessible_recording_ids:
return []
# Build base query for chunks from accessible recordings with eager loading
chunks_query = TranscriptChunk.query.options(joinedload(TranscriptChunk.recording)).filter(
TranscriptChunk.recording_id.in_(accessible_recording_ids)
)
        # Apply filters if provided; track the Recording join so it is never added twice
        joined_recording = False
        if filters:
            if filters.get('tag_ids'):
                chunks_query = chunks_query.join(Recording).join(
                    RecordingTag, Recording.id == RecordingTag.recording_id
                ).filter(RecordingTag.tag_id.in_(filters['tag_ids']))
                joined_recording = True
            if filters.get('speaker_names'):
                # Filter by participants field in recordings instead of chunk speaker_name
                if not joined_recording:
                    chunks_query = chunks_query.join(Recording)
                    joined_recording = True
                # Build OR conditions for each speaker name in participants
                speaker_conditions = [
                    Recording.participants.ilike(f'%{speaker_name}%')
                    for speaker_name in filters['speaker_names']
                ]
                chunks_query = chunks_query.filter(db.or_(*speaker_conditions))
                current_app.logger.info(f"Applied speaker filter for: {filters['speaker_names']}")
            if filters.get('recording_ids'):
                chunks_query = chunks_query.filter(
                    TranscriptChunk.recording_id.in_(filters['recording_ids'])
                )
            if filters.get('date_from') or filters.get('date_to'):
                if not joined_recording:
                    chunks_query = chunks_query.join(Recording)
                    joined_recording = True
                if filters.get('date_from'):
                    chunks_query = chunks_query.filter(Recording.meeting_date >= filters['date_from'])
                if filters.get('date_to'):
                    chunks_query = chunks_query.filter(Recording.meeting_date <= filters['date_to'])
# Text search - filter stop words and rank by match count
stop_words = {'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
'would', 'could', 'should', 'may', 'might', 'shall', 'can',
'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by', 'from',
'up', 'about', 'into', 'through', 'during', 'before', 'after',
'and', 'but', 'or', 'nor', 'not', 'so', 'yet', 'both',
'it', 'its', 'this', 'that', 'these', 'those', 'what', 'which',
'who', 'whom', 'how', 'when', 'where', 'why',
'i', 'me', 'my', 'we', 'our', 'you', 'your', 'he', 'she',
'his', 'her', 'they', 'them', 'their'}
query_words = [w for w in query.lower().split() if w not in stop_words and len(w) > 1]
if not query_words:
# If all words were stop words, fall back to using original query words
query_words = [w for w in query.lower().split() if len(w) > 1]
if query_words:
            from sqlalchemy import or_
# Filter: match ANY keyword (OR) to get candidates
text_conditions = []
for word in query_words:
text_conditions.append(TranscriptChunk.content.ilike(f'%{word}%'))
chunks_query = chunks_query.filter(or_(*text_conditions))
# Fetch more candidates than needed so we can rank them
chunks = chunks_query.limit(top_k * 5).all()
# Rank by how many query words each chunk matches
scored_chunks = []
for chunk in chunks:
content_lower = chunk.content.lower()
match_count = sum(1 for word in query_words if word in content_lower)
score = match_count / len(query_words) # 0.0 to 1.0
scored_chunks.append((chunk, score))
# Sort by score descending, take top_k
scored_chunks.sort(key=lambda x: x[1], reverse=True)
return scored_chunks[:top_k]
# No usable query words
return []
except Exception as e:
current_app.logger.error(f"Error in basic text search: {e}")
return []
def semantic_search_chunks(user_id, query, filters=None, top_k=5):
"""
Perform semantic search on transcript chunks with filtering.
Searches across user's own recordings and recordings shared with them.
Args:
user_id (int): User ID for permission filtering
query (str): Search query
filters (dict): Optional filters for tags, speakers, dates, recording_ids
top_k (int): Number of top chunks to return
Returns:
list: List of relevant chunks with similarity scores
"""
try:
# If embeddings are not available, fall back to basic text search
if not EMBEDDINGS_AVAILABLE:
current_app.logger.info("Embeddings not available - using basic text search as fallback")
return basic_text_search_chunks(user_id, query, filters, top_k)
# Generate embedding for the query
model = get_embedding_model()
if not model:
return basic_text_search_chunks(user_id, query, filters, top_k)
query_embedding = model.encode([query])[0]
# Get all accessible recording IDs (own + shared)
accessible_recording_ids = get_accessible_recording_ids(user_id)
if not accessible_recording_ids:
return []
# Build base query for chunks from accessible recordings with eager loading
chunks_query = TranscriptChunk.query.options(joinedload(TranscriptChunk.recording)).filter(
TranscriptChunk.recording_id.in_(accessible_recording_ids)
)
        # Apply filters if provided; track the Recording join so it is never added twice
        joined_recording = False
        if filters:
            if filters.get('tag_ids'):
                # Join with recordings that have specified tags
                chunks_query = chunks_query.join(Recording).join(
                    RecordingTag, Recording.id == RecordingTag.recording_id
                ).filter(RecordingTag.tag_id.in_(filters['tag_ids']))
                joined_recording = True
            if filters.get('speaker_names'):
                # Filter by participants field in recordings instead of chunk speaker_name
                if not joined_recording:
                    chunks_query = chunks_query.join(Recording)
                    joined_recording = True
                # Build OR conditions for each speaker name in participants
                speaker_conditions = [
                    Recording.participants.ilike(f'%{speaker_name}%')
                    for speaker_name in filters['speaker_names']
                ]
                chunks_query = chunks_query.filter(db.or_(*speaker_conditions))
                current_app.logger.info(f"Applied speaker filter for: {filters['speaker_names']}")
            if filters.get('recording_ids'):
                chunks_query = chunks_query.filter(
                    TranscriptChunk.recording_id.in_(filters['recording_ids'])
                )
            if filters.get('date_from') or filters.get('date_to'):
                if not joined_recording:
                    chunks_query = chunks_query.join(Recording)
                    joined_recording = True
                if filters.get('date_from'):
                    chunks_query = chunks_query.filter(Recording.meeting_date >= filters['date_from'])
                if filters.get('date_to'):
                    chunks_query = chunks_query.filter(Recording.meeting_date <= filters['date_to'])
# Get chunks that have embeddings
chunks = chunks_query.filter(TranscriptChunk.embedding.isnot(None)).all()
if not chunks:
return []
# Calculate similarities
chunk_similarities = []
for chunk in chunks:
try:
chunk_embedding = deserialize_embedding(chunk.embedding)
if chunk_embedding is not None:
similarity = cosine_similarity(
query_embedding.reshape(1, -1),
chunk_embedding.reshape(1, -1)
)[0][0]
chunk_similarities.append((chunk, float(similarity)))
except Exception as e:
current_app.logger.warning(f"Error calculating similarity for chunk {chunk.id}: {e}")
continue
# Sort by similarity and return top k
chunk_similarities.sort(key=lambda x: x[1], reverse=True)
return chunk_similarities[:top_k]
except Exception as e:
current_app.logger.error(f"Error in semantic search: {e}")
return []
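# Usage sketch (illustrative): retrieve the five most relevant chunks for a
# user's question, optionally restricted by tag and date filters. Results are
# (chunk, score) pairs sorted by descending similarity.
#
#     results = semantic_search_chunks(
#         user_id=current_user.id,
#         query="budget decisions for Q3",
#         filters={'tag_ids': [3], 'date_from': '2026-01-01'},
#         top_k=5,
#     )
#     for chunk, score in results:
#         print(chunk.recording_id, round(score, 3), chunk.content[:80])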

src/services/job_queue.py Normal file

@@ -0,0 +1,631 @@
"""
Fair database-backed job queue for background processing tasks.
This queue ensures:
- Jobs persist across application restarts
- Fair round-robin scheduling between users
- Separate queues for transcription (slow) and summary (fast) jobs
- Limited concurrency to prevent overwhelming external services
- Automatic recovery of orphaned jobs
"""
import os
import json
import threading
import time
import logging
from datetime import datetime
from typing import Optional, Dict, Any, List
from contextlib import contextmanager
logger = logging.getLogger(__name__)
# Configuration
TRANSCRIPTION_WORKERS = int(os.environ.get('JOB_QUEUE_WORKERS', '2'))
SUMMARY_WORKERS = int(os.environ.get('SUMMARY_QUEUE_WORKERS', '2'))
MAX_RETRIES = int(os.environ.get('JOB_MAX_RETRIES', '3'))
POLL_INTERVAL = 1.0 # seconds between checking for new jobs
# Job type categories
TRANSCRIPTION_JOBS = ['transcribe', 'reprocess_transcription']
SUMMARY_JOBS = ['summarize', 'reprocess_summary']
class FairJobQueue:
"""
A database-backed job queue with fair scheduling across users.
Uses separate queues for transcription and summary jobs to prevent
slow transcription jobs from blocking fast summary jobs.
"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
"""Singleton pattern to ensure only one queue exists."""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Initialize the job queue."""
if self._initialized:
return
self._transcription_workers = []
self._summary_workers = []
self._running = False
self._app = None
# Separate round-robin tracking for each queue
self._last_user_id_transcription = None
self._last_user_id_summary = None
# Lock for claiming jobs (SQLite doesn't support row-level locking)
self._claim_lock = threading.Lock()
self._initialized = True
logger.info(f"FairJobQueue initialized: {TRANSCRIPTION_WORKERS} transcription workers, {SUMMARY_WORKERS} summary workers")
def init_app(self, app):
"""Initialize with Flask app for context management."""
self._app = app
@contextmanager
def _app_context(self):
"""Get application context for database operations."""
if self._app:
with self._app.app_context():
yield
else:
yield
def start(self):
"""Start the worker threads for both queues."""
if self._running:
return
self._running = True
# Start transcription workers
for i in range(TRANSCRIPTION_WORKERS):
worker = threading.Thread(
target=self._worker_loop,
args=(TRANSCRIPTION_JOBS, 'transcription'),
name=f"TranscriptionWorker-{i}",
daemon=True
)
worker.start()
self._transcription_workers.append(worker)
# Start summary workers
for i in range(SUMMARY_WORKERS):
worker = threading.Thread(
target=self._worker_loop,
args=(SUMMARY_JOBS, 'summary'),
name=f"SummaryWorker-{i}",
daemon=True
)
worker.start()
self._summary_workers.append(worker)
logger.info(f"Started {TRANSCRIPTION_WORKERS} transcription workers and {SUMMARY_WORKERS} summary workers")
def stop(self):
"""Stop the worker threads gracefully."""
self._running = False
for worker in self._transcription_workers + self._summary_workers:
worker.join(timeout=5)
self._transcription_workers.clear()
self._summary_workers.clear()
logger.info("Job queue workers stopped")
def _worker_loop(self, job_types: List[str], queue_name: str):
"""Main worker loop that processes jobs of specific types."""
while self._running:
try:
job = self._claim_next_job(job_types, queue_name)
if job:
self._process_job(job)
else:
# No jobs available, sleep briefly
time.sleep(POLL_INTERVAL)
except Exception as e:
logger.error(f"{queue_name.capitalize()} worker error: {e}", exc_info=True)
time.sleep(POLL_INTERVAL)
def _claim_next_job(self, job_types: List[str], queue_name: str):
"""
Claim the next job of specified types using fair round-robin scheduling.
Args:
job_types: List of job types this worker handles
queue_name: Name of the queue ('transcription' or 'summary')
Returns the claimed job or None if no jobs available.
"""
# Use lock to prevent race conditions (SQLite doesn't support row-level locking)
with self._claim_lock:
with self._app_context():
from src.database import db
from src.models import ProcessingJob
try:
# Get list of users with queued jobs of our types
users_with_jobs = db.session.query(
ProcessingJob.user_id
).filter(
ProcessingJob.status == 'queued',
ProcessingJob.job_type.in_(job_types)
).group_by(
ProcessingJob.user_id
).order_by(
db.func.min(ProcessingJob.created_at)
).all()
if not users_with_jobs:
return None
user_ids = [u[0] for u in users_with_jobs]
# Get last user ID for this queue type
last_user_id = (self._last_user_id_transcription
if queue_name == 'transcription'
else self._last_user_id_summary)
# Round-robin: pick next user after last processed
next_user_id = None
if last_user_id is not None and last_user_id in user_ids:
idx = user_ids.index(last_user_id)
next_user_id = user_ids[(idx + 1) % len(user_ids)]
else:
next_user_id = user_ids[0]
# Get oldest queued job of our types for this user
candidate_job = ProcessingJob.query.filter(
ProcessingJob.user_id == next_user_id,
ProcessingJob.status == 'queued',
ProcessingJob.job_type.in_(job_types)
).order_by(
ProcessingJob.created_at
).first()
if candidate_job:
# Atomically claim the job - only succeeds if status is still 'queued'
# This prevents race conditions when multiple workers try to claim the same job
from sqlalchemy import update
claim_time = datetime.utcnow()
result = db.session.execute(
update(ProcessingJob)
.where(
ProcessingJob.id == candidate_job.id,
ProcessingJob.status == 'queued' # Critical: only claim if still queued
)
.values(status='processing', started_at=claim_time)
)
if result.rowcount == 0:
# Job was already claimed by another worker - this is expected with multiple workers
logger.debug(f"[{queue_name.upper()}] Job {candidate_job.id} already claimed by another worker")
db.session.rollback()
return None
# Also update Recording.status to reflect active processing
from src.models import Recording
recording = db.session.get(Recording, candidate_job.recording_id)
if recording and recording.status == 'QUEUED':
recording.status = 'PROCESSING'
db.session.commit()
# Refresh the job object to get updated values
db.session.refresh(candidate_job)
# Update last user ID for this queue
if queue_name == 'transcription':
self._last_user_id_transcription = next_user_id
else:
self._last_user_id_summary = next_user_id
wait_time = (claim_time - candidate_job.created_at).total_seconds()
logger.info(f"[{queue_name.upper()}] Claimed job {candidate_job.id} (type={candidate_job.job_type}) for user {candidate_job.user_id}, recording {candidate_job.recording_id} (waited {wait_time:.1f}s)")
return candidate_job
return None
except Exception as e:
logger.error(f"Error claiming {queue_name} job: {e}", exc_info=True)
db.session.rollback()
return None
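    # Round-robin illustration: with queued jobs for users [7, 9, 12] and the
    # last-served user being 9, the next claim goes to user 12, then wraps back
    # to 7, so no single user can monopolize the workers.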
def _is_permanent_error(self, error_str: str) -> bool:
"""
Detect if an error is permanent and should not be retried.
Permanent errors include:
- 400: Bad request (invalid format, invalid parameters)
- 413: File too large (user needs to enable chunking or compress file)
- 401/403: Authentication/authorization errors (credentials issue)
- 402: Payment required (billing issue)
- 404: Resource not found (model doesn't exist)
- Invalid format errors (file needs to be converted)
"""
error_lower = error_str.lower()
# HTTP status codes that indicate permanent errors
permanent_codes = ['400', '413', '401', '402', '403', '404']
for code in permanent_codes:
if f'error code: {code}' in error_lower or f'status {code}' in error_lower:
return True
# Specific error patterns that are permanent (simple substring matching)
permanent_patterns = [
'maximum content size limit',
'file too large',
'payload too large',
'invalid api key',
'incorrect api key',
'authentication failed',
'unauthorized',
'permission denied',
'access denied',
'billing',
'payment required',
'quota exceeded',
'insufficient funds',
'model not found',
'invalid model',
'unsupported format',
'invalid file format',
'invalid_request_error',
'bad request',
]
for pattern in permanent_patterns:
if pattern in error_lower:
return True
return False
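    # Illustration: "Error code: 413 - Maximum content size limit exceeded"
    # matches both the status-code check and the 'maximum content size limit'
    # pattern, so the job fails fast; a transient "connection reset by peer"
    # matches nothing and goes through the normal retry path.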
def _process_job(self, job):
"""Process a single job by dispatching to the appropriate task function."""
job_id = job.id
job_type = job.job_type
recording_id = job.recording_id
params_str = job.params
is_new_upload = job.is_new_upload
with self._app_context():
from src.database import db
from src.models import ProcessingJob, Recording
from flask import current_app
try:
# Parse job parameters
params = json.loads(params_str) if params_str else {}
# Re-fetch the job in this session context to ensure it's attached
job = db.session.get(ProcessingJob, job_id)
if not job:
logger.error(f"Job {job_id} not found when trying to process")
return
# Get recording
recording = db.session.get(Recording, recording_id)
if not recording:
raise ValueError(f"Recording {recording_id} not found")
# Dispatch based on job type
if job_type == 'transcribe':
self._run_transcription(job, recording, params)
elif job_type == 'summarize':
self._run_summarization(job, recording, params)
elif job_type == 'reprocess_transcription':
self._run_reprocess_transcription(job, recording, params)
elif job_type == 'reprocess_summary':
self._run_reprocess_summary(job, recording, params)
else:
raise ValueError(f"Unknown job type: {job_type}")
# Mark as completed - re-fetch to ensure we have latest state
job = db.session.get(ProcessingJob, job_id)
if job:
job.status = 'completed'
job.completed_at = datetime.utcnow()
db.session.commit()
logger.info(f"Job {job_id} completed successfully")
except Exception as e:
error_str = str(e)
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
# Check if this is a permanent error that shouldn't be retried
is_permanent_error = self._is_permanent_error(error_str)
# Re-fetch job to update it
job = db.session.get(ProcessingJob, job_id)
if job:
job.error_message = error_str
job.retry_count += 1
# Only retry if: not a permanent error AND under retry limit
if not is_permanent_error and job.retry_count < MAX_RETRIES:
# Re-queue for retry
job.status = 'queued'
job.started_at = None
logger.info(f"Job {job_id} re-queued for retry ({job.retry_count}/{MAX_RETRIES})")
else:
job.status = 'failed'
job.completed_at = datetime.utcnow()
recording = db.session.get(Recording, recording_id)
# Always keep recordings with FAILED status so users can see the error
# and reprocess later (e.g., when ASR server recovers)
if recording:
# Keep the recording with FAILED status so user can see the error and fix settings
recording.status = 'FAILED'
# Format the error for nice display
from src.utils.error_formatting import format_error_for_storage
recording.transcription = format_error_for_storage(error_str)
                        if is_permanent_error:
                            logger.error(f"Job {job_id} failed permanently (non-retryable error): {error_str[:100]}")
else:
logger.error(f"Job {job_id} failed permanently after {MAX_RETRIES} retries")
db.session.commit()
def _run_transcription(self, job, recording, params):
"""Run transcription task. Status updates handled by task function."""
from src.tasks.processing import transcribe_audio_task
from flask import current_app
filepath = recording.audio_path
filename_for_asr = recording.original_filename or os.path.basename(filepath)
transcribe_audio_task(
current_app._get_current_object().app_context(),
recording.id,
filepath,
filename_for_asr,
datetime.utcnow(),
language=params.get('language'),
min_speakers=params.get('min_speakers'),
max_speakers=params.get('max_speakers'),
tag_id=params.get('tag_id'),
hotwords=params.get('hotwords'),
initial_prompt=params.get('initial_prompt'),
)
def _run_summarization(self, job, recording, params):
"""Run summarization-only task. Status updates handled by task function."""
from src.tasks.processing import generate_summary_only_task
from flask import current_app
generate_summary_only_task(
current_app._get_current_object().app_context(),
recording.id,
custom_prompt_override=params.get('custom_prompt'),
user_id=params.get('user_id')
)
    def _run_reprocess_transcription(self, job, recording, params):
        """Run transcription reprocessing task. Same pipeline as _run_transcription."""
        self._run_transcription(job, recording, params)

    def _run_reprocess_summary(self, job, recording, params):
        """Run summary reprocessing task. Same pipeline as _run_summarization."""
        self._run_summarization(job, recording, params)
def enqueue(
self,
user_id: int,
recording_id: int,
job_type: str,
        params: Optional[Dict[str, Any]] = None,
is_new_upload: bool = False
) -> int:
"""
Add a job to the database queue.
Args:
user_id: ID of the user who owns this job
recording_id: ID of the recording to process
job_type: Type of job (transcribe, summarize, reprocess_transcription, reprocess_summary)
params: Optional parameters for the job
is_new_upload: True if this is a new file upload (for cleanup on failure)
Returns:
The created job ID
"""
with self._app_context():
from src.database import db
from src.models import ProcessingJob, Recording
# Check for existing active job of the SAME TYPE for this recording
# Allow different job types to coexist (e.g., transcribe and summarize)
existing = ProcessingJob.query.filter(
ProcessingJob.recording_id == recording_id,
ProcessingJob.job_type == job_type,
ProcessingJob.status.in_(['queued', 'processing'])
).first()
if existing:
logger.warning(f"Job of type {job_type} already exists for recording {recording_id}: {existing.id}")
return existing.id
# Create new job
job = ProcessingJob(
user_id=user_id,
recording_id=recording_id,
job_type=job_type,
params=json.dumps(params) if params else None,
is_new_upload=is_new_upload
)
db.session.add(job)
# Update recording status based on job type
recording = db.session.get(Recording, recording_id)
if recording:
if job_type in SUMMARY_JOBS:
recording.status = 'SUMMARIZING'
else:
recording.status = 'QUEUED'
db.session.commit()
# Auto-start workers if not running
if not self._running:
self.start()
queue_name = 'summary' if job_type in SUMMARY_JOBS else 'transcription'
logger.info(f"Enqueued {queue_name} job {job.id} (type={job_type}) for user {user_id}, recording {recording_id}")
return job.id
def recover_orphaned_jobs(self):
"""
Recover jobs that were processing when the app crashed.
Call this on startup to reset orphaned jobs back to queued.
"""
with self._app_context():
from src.database import db
from src.models import ProcessingJob
orphaned = ProcessingJob.query.filter(
ProcessingJob.status == 'processing'
).all()
for job in orphaned:
job.status = 'queued'
job.started_at = None
queue_name = 'summary' if job.job_type in SUMMARY_JOBS else 'transcription'
logger.info(f"Recovered orphaned {queue_name} job {job.id} for recording {job.recording_id}")
if orphaned:
db.session.commit()
logger.info(f"Recovered {len(orphaned)} orphaned jobs")
def get_queue_status(self) -> Dict[str, Any]:
"""Get the current queue status for both queues."""
with self._app_context():
from src.models import ProcessingJob
transcription_queued = ProcessingJob.query.filter(
ProcessingJob.status == 'queued',
ProcessingJob.job_type.in_(TRANSCRIPTION_JOBS)
).count()
transcription_processing = ProcessingJob.query.filter(
ProcessingJob.status == 'processing',
ProcessingJob.job_type.in_(TRANSCRIPTION_JOBS)
).count()
summary_queued = ProcessingJob.query.filter(
ProcessingJob.status == 'queued',
ProcessingJob.job_type.in_(SUMMARY_JOBS)
).count()
summary_processing = ProcessingJob.query.filter(
ProcessingJob.status == 'processing',
ProcessingJob.job_type.in_(SUMMARY_JOBS)
).count()
return {
"transcription_queue": {
"queued": transcription_queued,
"processing": transcription_processing,
"workers": TRANSCRIPTION_WORKERS
},
"summary_queue": {
"queued": summary_queued,
"processing": summary_processing,
"workers": SUMMARY_WORKERS
},
"is_running": self._running
}
def get_position_in_queue(self, recording_id: int) -> Optional[int]:
"""Get the position of a recording's job in its respective queue (1-indexed)."""
with self._app_context():
from src.models import ProcessingJob
job = ProcessingJob.query.filter(
ProcessingJob.recording_id == recording_id,
ProcessingJob.status == 'queued'
).first()
if not job:
return None
# Determine which queue this job is in
job_types = SUMMARY_JOBS if job.job_type in SUMMARY_JOBS else TRANSCRIPTION_JOBS
# Count jobs of the same type created before this one
position = ProcessingJob.query.filter(
ProcessingJob.status == 'queued',
ProcessingJob.job_type.in_(job_types),
ProcessingJob.created_at < job.created_at
).count() + 1
return position
def get_job_for_recording(self, recording_id: int):
"""Get the active job for a recording."""
with self._app_context():
from src.models import ProcessingJob
return ProcessingJob.query.filter(
ProcessingJob.recording_id == recording_id,
ProcessingJob.status.in_(['queued', 'processing'])
).first()
def cleanup_old_jobs(self, max_age_hours: int = 24):
"""Remove completed/failed jobs older than max_age_hours."""
with self._app_context():
from src.database import db
from src.models import ProcessingJob
from datetime import timedelta
cutoff = datetime.utcnow() - timedelta(hours=max_age_hours)
deleted = ProcessingJob.query.filter(
ProcessingJob.status.in_(['completed', 'failed']),
ProcessingJob.completed_at < cutoff
).delete(synchronize_session=False)
if deleted:
db.session.commit()
logger.info(f"Cleaned up {deleted} old jobs")
# Global job queue instance
job_queue = FairJobQueue()
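
# Illustrative usage sketch (not executed at import time): assumes a Flask app
# context and an existing Recording row with id 42 — both hypothetical here.
#
#     job_id = job_queue.enqueue(
#         user_id=1,
#         recording_id=42,
#         job_type='transcribe',
#         params={'language': 'fr', 'min_speakers': 2, 'max_speakers': 4},
#         is_new_upload=True,
#     )
#     status = job_queue.get_queue_status()
#     position = job_queue.get_position_in_queue(42)
#
# Note that a duplicate enqueue of the same job_type for the same recording
# returns the existing job id rather than creating a second queue entry.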

498
src/services/llm.py Normal file

@@ -0,0 +1,498 @@
"""
LLM API integration services (OpenAI/OpenRouter).
"""
import os
import re
import json
import logging
import httpx
from openai import OpenAI
# Use standard logging instead of current_app.logger for context independence
logger = logging.getLogger(__name__)
from src.utils import safe_json_loads, extract_json_object


class TokenBudgetExceeded(Exception):
    """Raised when user exceeds their token budget."""
    def __init__(self, message, usage_percentage=100):
        self.message = message
        self.usage_percentage = usage_percentage
        super().__init__(message)
# Configuration - use TEXT_MODEL_* variables for LLM
TEXT_MODEL_API_KEY = os.environ.get("TEXT_MODEL_API_KEY")
TEXT_MODEL_BASE_URL = os.environ.get("TEXT_MODEL_BASE_URL", "https://openrouter.ai/api/v1")
if TEXT_MODEL_BASE_URL:
TEXT_MODEL_BASE_URL = TEXT_MODEL_BASE_URL.split('#')[0].strip()
TEXT_MODEL_NAME = os.environ.get("TEXT_MODEL_NAME", "openai/gpt-3.5-turbo")
# Chat model configuration (optional - falls back to TEXT_MODEL_* if not set)
CHAT_MODEL_API_KEY = os.environ.get("CHAT_MODEL_API_KEY")
CHAT_MODEL_BASE_URL = os.environ.get("CHAT_MODEL_BASE_URL")
if CHAT_MODEL_BASE_URL:
CHAT_MODEL_BASE_URL = CHAT_MODEL_BASE_URL.split('#')[0].strip()
CHAT_MODEL_NAME = os.environ.get("CHAT_MODEL_NAME")
# Chat-specific GPT-5 settings (optional - falls back to main GPT5_* settings)
CHAT_GPT5_REASONING_EFFORT = os.environ.get("CHAT_GPT5_REASONING_EFFORT")
CHAT_GPT5_VERBOSITY = os.environ.get("CHAT_GPT5_VERBOSITY")
# Streaming options - disable for LLM servers that don't support OpenAI's stream_options
ENABLE_STREAM_OPTIONS = os.environ.get("ENABLE_STREAM_OPTIONS", "true").lower() == "true"
def get_chat_config():
"""
Get chat model configuration, falling back to TEXT_MODEL if not set.
Returns a dict with api_key, base_url, model_name, and GPT-5 settings.
"""
if CHAT_MODEL_API_KEY and CHAT_MODEL_NAME:
return {
'api_key': CHAT_MODEL_API_KEY,
'base_url': CHAT_MODEL_BASE_URL or TEXT_MODEL_BASE_URL,
'model_name': CHAT_MODEL_NAME,
'gpt5_reasoning_effort': CHAT_GPT5_REASONING_EFFORT or os.environ.get("GPT5_REASONING_EFFORT", "medium"),
'gpt5_verbosity': CHAT_GPT5_VERBOSITY or os.environ.get("GPT5_VERBOSITY", "medium")
}
return {
'api_key': TEXT_MODEL_API_KEY,
'base_url': TEXT_MODEL_BASE_URL,
'model_name': TEXT_MODEL_NAME,
'gpt5_reasoning_effort': os.environ.get("GPT5_REASONING_EFFORT", "medium"),
'gpt5_verbosity': os.environ.get("GPT5_VERBOSITY", "medium")
}
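# Example of the fallback (hypothetical env): with CHAT_MODEL_API_KEY or
# CHAT_MODEL_NAME unset, get_chat_config() returns the TEXT_MODEL_* values,
# e.g. {'api_key': TEXT_MODEL_API_KEY, 'base_url': 'https://openrouter.ai/api/v1',
# 'model_name': 'openai/gpt-3.5-turbo', 'gpt5_reasoning_effort': 'medium',
# 'gpt5_verbosity': 'medium'}.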
# Set up HTTP client with custom headers for OpenRouter app identification
app_headers = {
"HTTP-Referer": "https://github.com/murtaza-nasir/speakr",
"X-Title": "Speakr - AI Audio Transcription",
"User-Agent": "Speakr/1.0 (https://github.com/murtaza-nasir/speakr)"
}
http_client_no_proxy = httpx.Client(
verify=True,
headers=app_headers
)
# Create client with placeholder key if not provided (allows app to start)
try:
api_key = TEXT_MODEL_API_KEY or "not-needed"
client = OpenAI(
api_key=api_key,
base_url=TEXT_MODEL_BASE_URL,
http_client=http_client_no_proxy
)
except Exception as client_init_e:
    logger.error(f"Failed to initialize LLM client: {client_init_e}")
    client = None
# Create chat client (may be same as main client if no separate config)
chat_client = None
try:
chat_config = get_chat_config()
if chat_config['api_key']:
if CHAT_MODEL_API_KEY and CHAT_MODEL_API_KEY != TEXT_MODEL_API_KEY:
# Separate chat configuration - create dedicated client
chat_client = OpenAI(
api_key=chat_config['api_key'],
base_url=chat_config['base_url'],
http_client=http_client_no_proxy
)
logger.info(f"Separate chat client initialized: {chat_config['base_url']} / {chat_config['model_name']}")
else:
# Use same client as main LLM
chat_client = client
except Exception as chat_client_init_e:
logger.warning(f"Failed to initialize chat client, falling back to main client: {chat_client_init_e}")
chat_client = client
def is_gpt5_model(model_name):
"""
Check if the model is a GPT-5 series model that requires special API parameters.
Args:
model_name: The model name string
Returns:
Boolean indicating if this is a GPT-5 model
"""
if not model_name:
return False
model_lower = model_name.lower()
return model_lower.startswith('gpt-5') or model_lower in ['gpt-5', 'gpt-5-mini', 'gpt-5-nano', 'gpt-5-chat-latest']
def is_using_openai_api():
"""
Check if we're using the official OpenAI API (not OpenRouter or other providers).
Returns:
Boolean indicating if this is the OpenAI API
"""
return TEXT_MODEL_BASE_URL and 'api.openai.com' in TEXT_MODEL_BASE_URL
def call_llm_completion(messages, temperature=0.7, response_format=None, stream=False, max_tokens=None,
user_id=None, operation_type=None):
"""
Centralized function for LLM API calls with proper error handling and logging.
Args:
messages: List of message dicts with 'role' and 'content'
temperature: Sampling temperature (0-1) - ignored for GPT-5 models
response_format: Optional response format dict (e.g., {"type": "json_object"})
stream: Whether to stream the response
max_tokens: Optional maximum tokens to generate
user_id: Optional user ID for token tracking and budget enforcement
operation_type: Optional operation type for token tracking (e.g., 'summarization', 'chat')
Returns:
OpenAI completion object or generator (if streaming)
"""
if not client:
raise ValueError("LLM client not initialized")
if not TEXT_MODEL_API_KEY:
raise ValueError("TEXT_MODEL_API_KEY not configured")
# Check budget before making the call
if user_id and operation_type:
try:
from src.services.token_tracking import token_tracker
can_proceed, usage_pct, msg = token_tracker.check_budget(user_id)
if not can_proceed:
raise TokenBudgetExceeded(msg, usage_pct)
if usage_pct >= 80:
logger.warning(f"User {user_id} at {usage_pct:.1f}% of token budget")
except TokenBudgetExceeded:
raise
except Exception as e:
# Log but don't block on budget check errors
logger.warning(f"Budget check failed for user {user_id}: {e}")
try:
# Check if we're using GPT-5 with OpenAI API
using_gpt5 = is_gpt5_model(TEXT_MODEL_NAME) and is_using_openai_api()
completion_args = {
"model": TEXT_MODEL_NAME,
"messages": messages,
"stream": stream
}
# Add stream_options to get usage in final chunk for streaming
# Some LLM servers don't support this OpenAI-specific option
if stream and ENABLE_STREAM_OPTIONS:
completion_args["stream_options"] = {"include_usage": True}
if using_gpt5:
# GPT-5 models don't support temperature, top_p, or logprobs
# They use reasoning_effort and verbosity instead
logger.debug(f"Using GPT-5 model: {TEXT_MODEL_NAME} - applying GPT-5 specific parameters")
# Get GPT-5 specific parameters from environment variables
reasoning_effort = os.environ.get("GPT5_REASONING_EFFORT", "medium") # minimal, low, medium, high
verbosity = os.environ.get("GPT5_VERBOSITY", "medium") # low, medium, high
# Add GPT-5 specific parameters
completion_args["reasoning_effort"] = reasoning_effort
completion_args["verbosity"] = verbosity
# Use max_completion_tokens instead of max_tokens for GPT-5
if max_tokens:
completion_args["max_completion_tokens"] = max_tokens
else:
# Non-GPT-5 models use standard parameters
completion_args["temperature"] = temperature
if max_tokens:
completion_args["max_tokens"] = max_tokens
if response_format:
completion_args["response_format"] = response_format
response = client.chat.completions.create(**completion_args)
# Track usage for non-streaming calls
if user_id and operation_type and not stream and response.usage:
try:
from src.services.token_tracking import token_tracker
token_tracker.record_usage(
user_id=user_id,
operation_type=operation_type,
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model_name=TEXT_MODEL_NAME,
cost=getattr(response.usage, 'cost', None)
)
except Exception as e:
logger.warning(f"Failed to record token usage: {e}")
# Debug log for empty responses
if not stream and response.choices:
content = response.choices[0].message.content
if not content:
logger.warning(f"LLM returned empty content. Model: {TEXT_MODEL_NAME}, finish_reason: {response.choices[0].finish_reason}")
# Log more details if available
if hasattr(response.choices[0].message, 'refusal'):
logger.warning(f"Refusal: {response.choices[0].message.refusal}")
if hasattr(response.choices[0].message, 'tool_calls') and response.choices[0].message.tool_calls:
logger.warning(f"Tool calls present: {response.choices[0].message.tool_calls}")
return response
except TokenBudgetExceeded:
raise
except Exception as e:
logger.error(f"LLM API call failed: {e}")
raise
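
# Minimal call sketch (hypothetical prompt; requires TEXT_MODEL_API_KEY):
#
#     response = call_llm_completion(
#         messages=[{"role": "user", "content": "Summarize this transcript: ..."}],
#         temperature=0.3,
#         response_format={"type": "json_object"},
#         user_id=1,
#         operation_type="summarization",
#     )
#     text = response.choices[0].message.content
#
# With stream=True the return value is an iterator of chunks instead; see
# process_streaming_with_thinking below for how streaming output is consumed.
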
def call_chat_completion(messages, temperature=0.7, response_format=None, stream=False, max_tokens=None,
user_id=None, operation_type=None):
"""
Chat-specific LLM completion function. Uses dedicated chat model if configured,
otherwise falls back to standard TEXT_MODEL configuration.
Args:
messages: List of message dicts with 'role' and 'content'
temperature: Sampling temperature (0-1) - ignored for GPT-5 models
response_format: Optional response format dict (e.g., {"type": "json_object"})
stream: Whether to stream the response
max_tokens: Optional maximum tokens to generate
user_id: Optional user ID for token tracking and budget enforcement
operation_type: Optional operation type for token tracking (e.g., 'chat')
Returns:
OpenAI completion object or generator (if streaming)
"""
effective_client = chat_client if chat_client else client
chat_config = get_chat_config()
if not effective_client:
raise ValueError("Chat LLM client not initialized")
if not chat_config['api_key']:
raise ValueError("Chat model API key not configured")
# Check budget before making the call
if user_id and operation_type:
try:
from src.services.token_tracking import token_tracker
can_proceed, usage_pct, msg = token_tracker.check_budget(user_id)
if not can_proceed:
raise TokenBudgetExceeded(msg, usage_pct)
if usage_pct >= 80:
logger.warning(f"User {user_id} at {usage_pct:.1f}% of token budget")
except TokenBudgetExceeded:
raise
except Exception as e:
# Log but don't block on budget check errors
logger.warning(f"Budget check failed for user {user_id}: {e}")
try:
model_name = chat_config['model_name']
base_url = chat_config['base_url'] or ''
# Check if we're using GPT-5 with OpenAI API
using_gpt5 = is_gpt5_model(model_name) and 'api.openai.com' in base_url
completion_args = {
"model": model_name,
"messages": messages,
"stream": stream
}
# Add stream_options to get usage in final chunk for streaming
# Some LLM servers don't support this OpenAI-specific option
if stream and ENABLE_STREAM_OPTIONS:
completion_args["stream_options"] = {"include_usage": True}
if using_gpt5:
logger.debug(f"Using GPT-5 chat model: {model_name}")
# Use chat-specific GPT-5 settings from config
completion_args["reasoning_effort"] = chat_config['gpt5_reasoning_effort']
completion_args["verbosity"] = chat_config['gpt5_verbosity']
if max_tokens:
completion_args["max_completion_tokens"] = max_tokens
else:
completion_args["temperature"] = temperature
if max_tokens:
completion_args["max_tokens"] = max_tokens
if response_format:
completion_args["response_format"] = response_format
response = effective_client.chat.completions.create(**completion_args)
# Track usage for non-streaming calls
if user_id and operation_type and not stream and response.usage:
try:
from src.services.token_tracking import token_tracker
token_tracker.record_usage(
user_id=user_id,
operation_type=operation_type,
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model_name=model_name,
cost=getattr(response.usage, 'cost', None)
)
except Exception as e:
logger.warning(f"Failed to record token usage: {e}")
# Debug log for empty responses
if not stream and response.choices:
content = response.choices[0].message.content
if not content:
logger.warning(f"Chat LLM returned empty content. Model: {model_name}, finish_reason: {response.choices[0].finish_reason}")
return response
except TokenBudgetExceeded:
raise
except Exception as e:
logger.error(f"Chat LLM API call failed: {e}")
raise
def format_api_error_message(error_str):
"""
Formats API error messages to be more user-friendly.
Specifically handles token limit errors with helpful suggestions.
"""
error_lower = error_str.lower()
# Check for token limit errors
if 'maximum context length' in error_lower and 'tokens' in error_lower:
return "[Summary generation failed: The transcription is too long for AI processing. Request your admin to try using a different LLM with a larger context size, or set a limit for the transcript_length_limit in the system settings.]"
# Check for other common API errors
if 'rate limit' in error_lower:
return "[Summary generation failed: API rate limit exceeded. Please try again in a few minutes.]"
if 'insufficient funds' in error_lower or 'quota exceeded' in error_lower:
return "[Summary generation failed: API quota exceeded. Please contact support.]"
if 'timeout' in error_lower:
return "[Summary generation failed: Request timed out. Please try again.]"
# For other errors, show a generic message
return f"[Summary generation failed: {error_str}]"
def process_streaming_with_thinking(stream, user_id=None, operation_type=None, model_name=None, app=None):
"""
Generator that processes a streaming response and separates thinking content.
Yields SSE-formatted data with 'delta' for regular content and 'thinking' for thinking content.
Args:
stream: The streaming response from the LLM API
user_id: Optional user ID for token tracking
operation_type: Optional operation type for token tracking
model_name: Optional model name for token tracking
app: Optional Flask app instance for database context in generators
"""
content_buffer = ""
in_thinking = False
thinking_buffer = ""
for chunk in stream:
# Check for usage in final chunk (from stream_options={'include_usage': True})
if hasattr(chunk, 'usage') and chunk.usage and user_id and operation_type:
try:
from src.services.token_tracking import token_tracker
# Use app context if provided (needed for generators where context may be lost)
if app:
with app.app_context():
token_tracker.record_usage(
user_id=user_id,
operation_type=operation_type,
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
model_name=model_name or TEXT_MODEL_NAME,
cost=getattr(chunk.usage, 'cost', None)
)
else:
token_tracker.record_usage(
user_id=user_id,
operation_type=operation_type,
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
model_name=model_name or TEXT_MODEL_NAME,
cost=getattr(chunk.usage, 'cost', None)
)
except Exception as e:
logger.warning(f"Failed to record streaming token usage: {e}")
# Process content delta
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
content_buffer += content
# Process the buffer to detect and handle thinking tags
while True:
if not in_thinking:
# Look for opening thinking tag
think_start = re.search(r'<think(?:ing)?>', content_buffer, re.IGNORECASE)
if think_start:
# Send any content before the thinking tag
before_thinking = content_buffer[:think_start.start()]
if before_thinking:
yield f"data: {json.dumps({'delta': before_thinking})}\n\n"
# Start capturing thinking content
in_thinking = True
content_buffer = content_buffer[think_start.end():]
thinking_buffer = ""
else:
# No thinking tag found, send accumulated content
if content_buffer:
yield f"data: {json.dumps({'delta': content_buffer})}\n\n"
content_buffer = ""
break
else:
# We're inside a thinking tag, look for closing tag
think_end = re.search(r'</think(?:ing)?>', content_buffer, re.IGNORECASE)
if think_end:
# Capture thinking content up to the closing tag
thinking_buffer += content_buffer[:think_end.start()]
# Send the thinking content as a special type
if thinking_buffer.strip():
yield f"data: {json.dumps({'thinking': thinking_buffer.strip()})}\n\n"
# Continue processing after the closing tag
in_thinking = False
content_buffer = content_buffer[think_end.end():]
thinking_buffer = ""
else:
# Still inside thinking tag, accumulate content
thinking_buffer += content_buffer
content_buffer = ""
break
# Handle any remaining content
if in_thinking and thinking_buffer:
# Unclosed thinking tag - send as thinking content
yield f"data: {json.dumps({'thinking': thinking_buffer.strip()})}\n\n"
elif content_buffer:
# Regular content
yield f"data: {json.dumps({'delta': content_buffer})}\n\n"
# Signal the end of the stream
yield f"data: {json.dumps({'end_of_stream': True})}\n\n"

219
src/services/retention.py Normal file

@@ -0,0 +1,219 @@
"""
Recording retention and auto-deletion services.
"""
import os
from datetime import datetime, timedelta
from flask import current_app
from src.database import db
from src.models import Recording, RecordingTag, Tag
ENABLE_AUTO_DELETION = os.environ.get('ENABLE_AUTO_DELETION', 'false').lower() == 'true'
GLOBAL_RETENTION_DAYS = int(os.environ.get('GLOBAL_RETENTION_DAYS', '0'))
DELETION_MODE = os.environ.get('DELETION_MODE', 'full_recording')
def is_recording_exempt_from_deletion(recording):
"""
Check if a recording is exempt from auto-deletion.
Args:
recording: Recording object to check
Returns:
Boolean indicating if the recording should be kept
"""
# Manual exemption flag
if recording.deletion_exempt:
return True
# Check if any of the recording's tags protect it from deletion
# Protection can be indicated by either protect_from_deletion flag OR retention_days == -1
for tag_assoc in recording.tag_associations:
if tag_assoc.tag.protect_from_deletion:
return True
if tag_assoc.tag.retention_days == -1:
return True
return False
def get_retention_days_for_recording(recording):
"""
Get the effective retention period for a recording.
Multi-tier system: tag retention (shortest) → global retention
Tags with retention_days set override the global retention policy.
If multiple tags have retention_days, the SHORTEST period is used (most conservative).
Note: retention_days == -1 indicates infinite retention (protected), which is handled separately.
Args:
recording: Recording object
Returns:
Integer days for retention period, or None if no retention applies
"""
# Collect all tag-level retention periods
# Skip -1 (infinite retention/protected) as that's handled in is_recording_exempt_from_deletion
tag_retention_periods = []
for tag_assoc in recording.tag_associations:
if tag_assoc.tag.retention_days and tag_assoc.tag.retention_days > 0:
tag_retention_periods.append(tag_assoc.tag.retention_days)
# If any tags have retention periods, use the shortest one (most conservative)
if tag_retention_periods:
return min(tag_retention_periods)
# Fall back to global retention
if GLOBAL_RETENTION_DAYS > 0:
return GLOBAL_RETENTION_DAYS
return None
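# Worked example of the resolution order (hypothetical tags): a recording
# tagged with retention_days=30 and retention_days=90 resolves to 30 (shortest
# wins); with no tag-level retention and GLOBAL_RETENTION_DAYS=365 it resolves
# to 365; with neither, None (no auto-deletion applies). A tag with
# retention_days=-1 never reaches this function's math — the recording is
# exempted upstream by is_recording_exempt_from_deletion().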
def process_auto_deletion():
"""
Process auto-deletion of recordings based on retention policies.
This can be called by a scheduled job or admin endpoint.
Supports per-recording retention via tag-level retention_days overrides.
Tags with retention_days set take precedence over global retention.
Returns:
Dictionary with deletion statistics
"""
if not ENABLE_AUTO_DELETION:
return {'error': 'Auto-deletion is not enabled'}
# Check if any retention policy exists (global or tag-level)
has_global_retention = GLOBAL_RETENTION_DAYS > 0
# We'll check for tag-level retention on a per-recording basis
if not has_global_retention:
# Still check recordings in case they have tag-level retention
current_app.logger.info("No global retention configured, checking for tag-level retention policies")
stats = {
'checked': 0,
'deleted_audio_only': 0,
'deleted_full': 0,
'exempted': 0,
'skipped_no_retention': 0,
'errors': 0
}
try:
# Get completed recordings to check
# In audio_only mode: Skip recordings where audio was already deleted
# In full_recording mode: Include all (to catch audio-only deletions for full cleanup)
if DELETION_MODE == 'audio_only':
all_recordings = Recording.query.filter(
Recording.status == 'COMPLETED',
Recording.audio_deleted_at.is_(None) # Skip already-deleted audio
).all()
else: # full_recording mode
all_recordings = Recording.query.filter(
Recording.status == 'COMPLETED'
).all()
stats['checked'] = len(all_recordings)
current_time = datetime.utcnow()
for recording in all_recordings:
try:
# Check if exempt from deletion entirely
if is_recording_exempt_from_deletion(recording):
stats['exempted'] += 1
continue
# Get the effective retention period for this specific recording
retention_days = get_retention_days_for_recording(recording)
if not retention_days:
# No retention policy applies to this recording
stats['skipped_no_retention'] += 1
continue
# Calculate the cutoff date for this specific recording
cutoff_date = current_time - timedelta(days=retention_days)
# Check if recording is past its retention period
if recording.created_at >= cutoff_date:
# Recording is still within retention period
continue
# Recording is past retention period - process deletion
# Determine deletion mode
if DELETION_MODE == 'audio_only':
# Delete only the audio file, keep transcription
if recording.audio_path and os.path.exists(recording.audio_path):
current_app.logger.info(f"Recording {recording.id} is past retention ({retention_days} days), deleting audio")
os.remove(recording.audio_path)
current_app.logger.info(f"Auto-deleted audio file: {recording.audio_path}")
recording.audio_deleted_at = datetime.utcnow()
db.session.commit()
stats['deleted_audio_only'] += 1
else:
# Audio already deleted or doesn't exist - just mark timestamp
if not recording.audio_deleted_at:
recording.audio_deleted_at = datetime.utcnow()
db.session.commit()
current_app.logger.debug(f"Recording {recording.id} audio file not found, marked as deleted")
else: # full_recording mode
# Check if this is completing a previous audio_only deletion
if recording.audio_deleted_at:
current_app.logger.info(f"Recording {recording.id} has deleted audio (mode changed), completing full deletion")
else:
current_app.logger.info(f"Recording {recording.id} is past retention ({retention_days} days), deleting fully")
# Delete audio file if it exists
if recording.audio_path and os.path.exists(recording.audio_path):
os.remove(recording.audio_path)
# Delete associated processing jobs (required due to NOT NULL constraint)
from src.models.processing_job import ProcessingJob
ProcessingJob.query.filter_by(recording_id=recording.id).delete()
# Delete the database record (cascades to chunks, shares, etc.)
db.session.delete(recording)
db.session.commit()
stats['deleted_full'] += 1
current_app.logger.info(f"Auto-deleted full recording ID: {recording.id}")
except Exception as e:
stats['errors'] += 1
current_app.logger.error(f"Error auto-deleting recording {recording.id}: {e}")
db.session.rollback()
# After processing recording deletions, clean up orphaned speaker profiles
try:
from src.services.speaker_cleanup import cleanup_orphaned_speakers
speaker_stats = cleanup_orphaned_speakers()
stats['speakers_deleted'] = speaker_stats['speakers_deleted']
stats['embeddings_cleaned'] = speaker_stats['embeddings_removed']
stats['speakers_evaluated'] = speaker_stats['speakers_evaluated']
current_app.logger.info(
f"Speaker cleanup completed: {speaker_stats['speakers_deleted']} speakers deleted, "
f"{speaker_stats['embeddings_removed']} embedding references removed"
)
except Exception as e:
current_app.logger.error(f"Error during speaker cleanup: {e}", exc_info=True)
stats['speaker_cleanup_error'] = str(e)
current_app.logger.info(f"Auto-deletion completed: {stats}")
return stats
except Exception as e:
current_app.logger.error(f"Error during auto-deletion process: {e}", exc_info=True)
return {'error': str(e)}

217
src/services/speaker.py Normal file

@@ -0,0 +1,217 @@
"""
Speaker identification and management services.
"""
import os
import re
from datetime import datetime
from flask import current_app
from flask_login import current_user
from src.database import db
from src.models import Speaker, SystemSetting
from src.services.llm import call_llm_completion
from src.utils import safe_json_loads
# NOTE: format_transcription_for_llm was referenced but never defined in the
# original code; the stub below is a provisional implementation.
def format_transcription_for_llm(transcription):
    """
    Format transcription for LLM processing.
    TODO: This function needs a proper implementation.
    If transcription is a JSON array of diarized segments, format each as
    "[SPEAKER] text"; otherwise return the raw string.
    """
    if isinstance(transcription, str):
        try:
            import json
            data = json.loads(transcription)
            # Diarized JSON format: segments elsewhere in the codebase use the
            # 'sentence' key; fall back to 'text' for older payloads.
            if isinstance(data, list):
                return '\n'.join(
                    f"[{seg.get('speaker', 'UNKNOWN')}] {seg.get('sentence') or seg.get('text', '')}"
                    for seg in data
                    if isinstance(seg, dict) and ('sentence' in seg or 'text' in seg)
                )
        except (json.JSONDecodeError, TypeError):
            pass
    return str(transcription)
# Import TEXT_MODEL_API_KEY from llm service
from src.services.llm import TEXT_MODEL_API_KEY
def update_speaker_usage(speaker_names):
"""Helper function to update speaker usage statistics."""
if not speaker_names or not current_user.is_authenticated:
return
try:
for name in speaker_names:
name = name.strip()
if not name:
continue
speaker = Speaker.query.filter_by(user_id=current_user.id, name=name).first()
if speaker:
speaker.use_count += 1
speaker.last_used = datetime.utcnow()
else:
# Create new speaker
speaker = Speaker(
name=name,
user_id=current_user.id,
use_count=1,
created_at=datetime.utcnow(),
last_used=datetime.utcnow()
)
db.session.add(speaker)
db.session.commit()
except Exception as e:
current_app.logger.error(f"Error updating speaker usage: {e}")
db.session.rollback()
def identify_speakers_from_text(transcription):
"""
Uses an LLM to identify speakers from a transcription.
"""
if not TEXT_MODEL_API_KEY:
raise ValueError("TEXT_MODEL_API_KEY not configured.")
# The transcription passed here could be JSON, so we format it.
formatted_transcription = format_transcription_for_llm(transcription)
# Extract existing speaker labels (e.g., SPEAKER_00, SPEAKER_01) in order of appearance
all_labels = re.findall(r'\[(SPEAKER_\d+)\]', formatted_transcription)
seen = set()
speaker_labels = [x for x in all_labels if not (x in seen or seen.add(x))]
if not speaker_labels:
return {}
# Get configurable transcript length limit
transcript_limit = SystemSetting.get_setting('transcript_length_limit', 30000)
if transcript_limit == -1:
# No limit
transcript_text = formatted_transcription
else:
transcript_text = formatted_transcription[:transcript_limit]
prompt = f"""Analyze the following transcription and identify the names of the speakers. The speakers are labeled as {', '.join(speaker_labels)}. Based on the context of the conversation, determine the most likely name for each speaker label.
Transcription:
---
{transcript_text}
---
Respond with a single JSON object where keys are the speaker labels (e.g., "SPEAKER_00") and values are the identified full names. If a name cannot be determined, use the value "Unknown".
Example:
{{
"SPEAKER_00": "John Doe",
"SPEAKER_01": "Jane Smith",
"SPEAKER_02": "Unknown"
}}
JSON Response:
"""
try:
completion = call_llm_completion(
messages=[
{"role": "system", "content": "You are an expert in analyzing conversation transcripts to identify speakers. Your response must be a single, valid JSON object."},
{"role": "user", "content": prompt}
],
temperature=0.2
)
response_content = completion.choices[0].message.content
speaker_map = safe_json_loads(response_content, {})
# Post-process the map to replace "Unknown" with an empty string
for speaker_label, identified_name in speaker_map.items():
if identified_name.strip().lower() == "unknown":
speaker_map[speaker_label] = ""
return speaker_map
except Exception as e:
current_app.logger.error(f"Error calling LLM for speaker identification: {e}")
raise
def identify_unidentified_speakers_from_text(transcription, unidentified_speakers):
"""
Uses an LLM to identify only the unidentified speakers from a transcription.
"""
if not TEXT_MODEL_API_KEY:
raise ValueError("TEXT_MODEL_API_KEY not configured.")
# The transcription passed here could be JSON, so we format it.
formatted_transcription = format_transcription_for_llm(transcription)
if not unidentified_speakers:
return {}
# Get configurable transcript length limit
transcript_limit = SystemSetting.get_setting('transcript_length_limit', 30000)
if transcript_limit == -1:
# No limit
transcript_text = formatted_transcription
else:
transcript_text = formatted_transcription[:transcript_limit]
prompt = f"""Analyze the following conversation transcript and identify the names of the UNIDENTIFIED speakers based on the context and content of their dialogue.
The speakers that need to be identified are: {', '.join(unidentified_speakers)}
Look for clues in the conversation such as:
- Names mentioned by other speakers when addressing someone
- Self-introductions or references to their own name
- Context clues about roles, relationships, or positions
- Any direct mentions of names in the dialogue
Here is the complete conversation transcript:
{transcript_text}
Based on the conversation above, identify the most likely real names for the unidentified speakers. Pay close attention to how speakers address each other and any names that are mentioned in the dialogue.
Respond with a single JSON object where keys are the speaker labels (e.g., "SPEAKER_01") and values are the identified full names. If a name cannot be determined from the conversation context, use an empty string "".
Example format:
{{
"SPEAKER_01": "Jane Smith",
"SPEAKER_03": "Bob Johnson",
"SPEAKER_05": ""
}}
JSON Response:
"""
try:
current_app.logger.info(f"[Auto-Identify] Calling LLM to identify speakers: {unidentified_speakers}")
current_app.logger.info(f"[Auto-Identify] Transcript excerpt (first 500 chars): {transcript_text[:500]}")
completion = call_llm_completion(
messages=[
{"role": "system", "content": "You are an expert in analyzing conversation transcripts to identify speakers based on contextual clues in the dialogue. Analyze the conversation carefully to find names mentioned when speakers address each other or introduce themselves. Your response must be a single, valid JSON object containing only the requested speaker identifications."},
{"role": "user", "content": prompt}
],
temperature=0.2
)
response_content = completion.choices[0].message.content
current_app.logger.info(f"[Auto-Identify] LLM Raw Response: {response_content}")
speaker_map = safe_json_loads(response_content, {})
current_app.logger.info(f"[Auto-Identify] Parsed speaker_map: {speaker_map}")
# Post-process the map to replace "Unknown" with an empty string
for speaker_label, identified_name in speaker_map.items():
if identified_name and identified_name.strip().lower() in ["unknown", "n/a", "not available", "unclear"]:
speaker_map[speaker_label] = ""
current_app.logger.info(f"[Auto-Identify] Final speaker_map after post-processing: {speaker_map}")
return speaker_map
except Exception as e:
current_app.logger.error(f"Error calling LLM for speaker identification: {e}", exc_info=True)
raise

295
src/services/speaker_cleanup.py Normal file

@@ -0,0 +1,295 @@
"""
Speaker cleanup service for managing orphaned speaker voice profiles.
This module provides automatic cleanup of speaker records when their associated
recordings are deleted through auto-deletion or manual deletion processes.
By default, speaker profiles (including voice embeddings) are preserved even
when all their recordings are deleted, since embeddings are aggregated and
represent hours of manual identification work. Set DELETE_ORPHANED_SPEAKERS=true
to enable automatic cleanup of speakers with no remaining recordings.
"""
import os
import logging
import json
from datetime import datetime
from sqlalchemy import exists
from src.database import db
from src.models import Speaker, SpeakerSnippet, Recording
logger = logging.getLogger(__name__)
def cleanup_orphaned_speakers(dry_run=False):
"""
Clean up speaker records that no longer have any associated recordings.
Only runs if DELETE_ORPHANED_SPEAKERS=true is set. By default, speaker
profiles are preserved because voice embeddings are aggregated values
that can't be reconstructed from recordings alone.
A speaker is considered orphaned when:
- It has no SpeakerSnippet records
- Its embeddings_history contains no valid recording references
Args:
dry_run (bool): If True, only report what would be deleted without actually deleting
Returns:
dict: Statistics about cleanup operation
{
'speakers_deleted': int,
'embeddings_removed': int,
'speakers_evaluated': int,
'orphaned_speakers': list of dict (if dry_run=True)
}
"""
delete_orphans = os.environ.get('DELETE_ORPHANED_SPEAKERS', 'false').lower() in ('true', '1', 'yes')
if not delete_orphans:
logger.debug("Speaker cleanup skipped (DELETE_ORPHANED_SPEAKERS is not enabled)")
return {
'speakers_deleted': 0,
'embeddings_removed': 0,
'speakers_evaluated': 0,
'orphaned_speakers': []
}
logger.info("Starting speaker cleanup process (dry_run=%s)", dry_run)
stats = {
'speakers_deleted': 0,
'embeddings_removed': 0,
'speakers_evaluated': 0,
'orphaned_speakers': []
}
try:
# Clean embeddings_history references first
embeddings_cleaned = clean_embeddings_history_references(dry_run=dry_run)
stats['embeddings_removed'] = embeddings_cleaned
# Find and process orphaned speakers
orphaned_speaker_ids = get_orphaned_speakers()
stats['speakers_evaluated'] = Speaker.query.count()
if not orphaned_speaker_ids:
logger.info("No orphaned speakers found")
return stats
logger.info("Found %d orphaned speaker(s)", len(orphaned_speaker_ids))
if dry_run:
# Report what would be deleted
for speaker_id in orphaned_speaker_ids:
                speaker = db.session.get(Speaker, speaker_id)
if speaker:
stats['orphaned_speakers'].append({
'id': speaker.id,
'name': speaker.name,
'user_id': speaker.user_id,
'embedding_count': speaker.embedding_count
})
logger.info("Dry run: Would delete %d speakers", len(orphaned_speaker_ids))
else:
# Actually delete orphaned speakers
for speaker_id in orphaned_speaker_ids:
                speaker = db.session.get(Speaker, speaker_id)
if speaker:
logger.debug(
"Deleting orphaned speaker: id=%d, name='%s', user_id=%d, embedding_count=%d",
speaker.id, speaker.name, speaker.user_id, speaker.embedding_count or 0
)
db.session.delete(speaker)
stats['speakers_deleted'] += 1
# Commit all deletions
db.session.commit()
logger.info("Speaker cleanup completed: %d speakers deleted", stats['speakers_deleted'])
# Warning if large number deleted
if stats['speakers_deleted'] >= 50:
logger.warning(
"Large number of speakers deleted (%d). Review cleanup logic if unexpected.",
stats['speakers_deleted']
)
return stats
except Exception as e:
db.session.rollback()
logger.error("Error during speaker cleanup: %s", str(e), exc_info=True)
raise
def clean_embeddings_history_references(dry_run=False):
"""
Clean embeddings_history JSON fields to remove references to deleted recordings.
Scans all speakers' embeddings_history and removes entries where the
recording_id no longer exists in the database.
Args:
dry_run (bool): If True, only count what would be cleaned
Returns:
int: Number of embedding references removed
"""
logger.debug("Cleaning embeddings_history references (dry_run=%s)", dry_run)
references_removed = 0
try:
# Get all speakers with embeddings_history
speakers = Speaker.query.filter(Speaker.embeddings_history.isnot(None)).all()
for speaker in speakers:
try:
# Parse embeddings_history JSON
if not speaker.embeddings_history:
continue
history = speaker.embeddings_history if isinstance(speaker.embeddings_history, list) else json.loads(speaker.embeddings_history)
if not history or not isinstance(history, list):
continue
# Filter out entries with deleted recording_ids
cleaned_history = []
for entry in history:
if not isinstance(entry, dict) or 'recording_id' not in entry:
continue
recording_id = entry['recording_id']
# Check if recording still exists
recording_exists = db.session.query(
exists().where(Recording.id == recording_id)
).scalar()
if recording_exists:
cleaned_history.append(entry)
else:
references_removed += 1
logger.debug(
"Removing deleted recording reference: speaker_id=%d, recording_id=%d",
speaker.id, recording_id
)
# Update speaker if history changed
if len(cleaned_history) < len(history):
if not dry_run:
speaker.embeddings_history = cleaned_history
logger.debug(
"Updated speaker %d embeddings_history: %d -> %d entries",
speaker.id, len(history), len(cleaned_history)
)
except (json.JSONDecodeError, TypeError, KeyError) as e:
logger.warning(
"Error processing embeddings_history for speaker %d: %s",
speaker.id, str(e)
)
continue
if not dry_run and references_removed > 0:
db.session.commit()
logger.debug("Cleaned %d embedding references", references_removed)
return references_removed
except Exception as e:
db.session.rollback()
logger.error("Error cleaning embeddings_history: %s", str(e), exc_info=True)
raise
def get_orphaned_speakers(user_id=None):
"""
Get list of speaker IDs that are orphaned (no associated recordings).
A speaker is orphaned when:
- It has no SpeakerSnippet records
- After cleaning embeddings_history, it has no valid recording references
Args:
user_id (int, optional): Filter to specific user's speakers
Returns:
list: List of speaker IDs that are orphaned
"""
logger.debug("Finding orphaned speakers (user_id=%s)", user_id)
# Query for speakers with no snippets
query = Speaker.query.filter(
~exists().where(SpeakerSnippet.speaker_id == Speaker.id)
)
if user_id is not None:
query = query.filter(Speaker.user_id == user_id)
speakers_without_snippets = query.all()
orphaned_ids = []
for speaker in speakers_without_snippets:
# Check if embeddings_history has any valid recording references
has_valid_recordings = False
if speaker.embeddings_history:
try:
history = speaker.embeddings_history if isinstance(speaker.embeddings_history, list) else json.loads(speaker.embeddings_history)
if history and isinstance(history, list):
for entry in history:
if isinstance(entry, dict) and 'recording_id' in entry:
recording_id = entry['recording_id']
# Check if this recording exists
recording_exists = db.session.query(
exists().where(Recording.id == recording_id)
).scalar()
if recording_exists:
has_valid_recordings = True
break
except (json.JSONDecodeError, TypeError, KeyError):
pass
# If no snippets AND no valid recording references, it's orphaned
if not has_valid_recordings:
orphaned_ids.append(speaker.id)
logger.debug(
"Speaker %d ('%s') is orphaned: no snippets, no valid recordings",
speaker.id, speaker.name
)
return orphaned_ids
def get_speaker_cleanup_statistics():
"""
Get statistics about speaker data for monitoring.
Returns:
dict: Statistics about speakers
{
'total_speakers': int,
'speakers_with_snippets': int,
'speakers_with_embeddings': int,
'potential_orphans': int
}
"""
stats = {
'total_speakers': Speaker.query.count(),
'speakers_with_snippets': db.session.query(Speaker.id).join(
SpeakerSnippet, Speaker.id == SpeakerSnippet.speaker_id
).distinct().count(),
'speakers_with_embeddings': Speaker.query.filter(
Speaker.average_embedding.isnot(None)
).count(),
'potential_orphans': len(get_orphaned_speakers())
}
return stats
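
# Dry-run sketch (requires DELETE_ORPHANED_SPEAKERS=true and an app context;
# the numbers shown are illustrative):
#
#     stats = cleanup_orphaned_speakers(dry_run=True)
#     # {'speakers_deleted': 0, 'embeddings_removed': 2,
#     #  'speakers_evaluated': 14,
#     #  'orphaned_speakers': [{'id': 7, 'name': 'Jane', 'user_id': 3,
#     #                         'embedding_count': 1}]}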


@@ -0,0 +1,453 @@
"""
Speaker Embedding Matcher Service.
This service handles voice embedding comparison and matching for speaker identification.
It provides functions to:
- Serialize/deserialize speaker embeddings for database storage
- Calculate cosine similarity between voice embeddings
- Find matching speakers based on voice similarity
- Update speaker profiles with new embeddings
- Calculate confidence scores for speaker profiles
Uses 256-dimensional embeddings from WhisperX diarization.
"""
import json
import logging
import numpy as np
from datetime import datetime

logger = logging.getLogger(__name__)
try:
from sklearn.metrics.pairwise import cosine_similarity
except ImportError:
cosine_similarity = None
from src.database import db
from src.models import Speaker
def serialize_embedding(embedding_array):
"""
Convert numpy array or list to binary for database storage.
Args:
embedding_array: numpy array or list of floats (256 dimensions)
Returns:
bytes: Binary representation (1,024 bytes for 256 × float32)
"""
return np.array(embedding_array, dtype=np.float32).tobytes()
def deserialize_embedding(binary_data):
"""
Convert binary data back to numpy array.
Args:
binary_data: bytes from database (1,024 bytes)
Returns:
numpy.ndarray: 256-dimensional float32 array
"""
return np.frombuffer(binary_data, dtype=np.float32)
def calculate_similarity(embedding1, embedding2):
"""
Compute cosine similarity between two 256-dimensional voice embeddings.
Args:
embedding1: numpy array, list, or binary data
embedding2: numpy array, list, or binary data
Returns:
float: Similarity score (0-1, where 1 is identical)
"""
    if cosine_similarity is None:
        raise ImportError(
            "scikit-learn is required for voice similarity; install it to enable speaker matching"
        )
    # Convert to numpy arrays if needed
    e1 = np.array(embedding1, dtype=np.float32).reshape(1, -1)
e2 = np.array(embedding2, dtype=np.float32).reshape(1, -1)
# Cosine similarity returns values from -1 to 1
# For voice embeddings, we typically see 0.6-0.99 range
return float(cosine_similarity(e1, e2)[0][0])
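
# Round-trip and similarity sketch (random vectors stand in for real WhisperX
# embeddings):
#
#     emb = np.random.rand(256).astype(np.float32)
#     blob = serialize_embedding(emb)        # 1,024 bytes (256 * float32)
#     restored = deserialize_embedding(blob)
#     assert np.allclose(restored, emb)
#     calculate_similarity(emb, restored)    # ~1.0 (identical vectors)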
def find_matching_speakers(target_embedding, user_id, threshold=0.70):
"""
Find speakers matching a target voice embedding for a specific user.
Args:
target_embedding: The voice embedding to match against (256-dim array/list)
user_id: User ID to search within
threshold: Minimum similarity score (0-1, default 0.70 = 70%)
Returns:
list: Sorted list of matching speakers with scores
[{'speaker_id': 5, 'name': 'John', 'similarity': 85.3, 'confidence': 0.92}, ...]
"""
# Get all speakers with embeddings for this user
speakers = Speaker.query.filter_by(user_id=user_id).filter(
Speaker.average_embedding.isnot(None)
).all()
if not speakers:
return []
matches = []
for speaker in speakers:
try:
# Deserialize and compare
speaker_emb = deserialize_embedding(speaker.average_embedding)
similarity = calculate_similarity(target_embedding, speaker_emb)
if similarity >= threshold:
matches.append({
'speaker_id': speaker.id,
'name': speaker.name,
'similarity': round(similarity * 100, 1), # Convert to percentage
'confidence': speaker.confidence_score or 0.5,
'embedding_count': speaker.embedding_count or 0
})
        except Exception as e:
            # Skip speakers with corrupted embeddings
            logger.debug("Skipping speaker %s: unreadable embedding (%s)", speaker.id, e)
            continue
# Sort by similarity (highest first)
return sorted(matches, key=lambda x: x['similarity'], reverse=True)
def update_speaker_embedding(speaker, new_embedding, recording_id):
"""
Update a speaker's average embedding and history with a new sample.
Uses weighted moving average to update the profile:
- New embeddings get 30% weight
- Existing average gets 70% weight
Args:
speaker: Speaker model instance
new_embedding: New voice embedding (256-dim array/list)
recording_id: ID of the recording this embedding came from
Returns:
float: Similarity between new embedding and previous average (None if first)
"""
new_emb_array = np.array(new_embedding, dtype=np.float32)
similarity_to_avg = None
if speaker.average_embedding is None:
# First embedding for this speaker
speaker.average_embedding = serialize_embedding(new_emb_array)
speaker.embedding_count = 1
speaker.embeddings_history = [{
'recording_id': recording_id,
'timestamp': datetime.utcnow().isoformat(),
'similarity': 100.0 # Perfect match to itself
}]
else:
# Update existing average
current_avg = deserialize_embedding(speaker.average_embedding)
similarity_to_avg = calculate_similarity(new_emb_array, current_avg)
# Weighted average: 30% new, 70% existing
# This prevents sudden shifts while still adapting to voice changes
weight = 0.3
updated_avg = (1 - weight) * current_avg + weight * new_emb_array
speaker.average_embedding = serialize_embedding(updated_avg)
speaker.embedding_count += 1
# Add to history (keep last 10 entries)
history = speaker.embeddings_history or []
history.append({
'recording_id': recording_id,
'timestamp': datetime.utcnow().isoformat(),
'similarity': round(similarity_to_avg * 100, 1)
})
speaker.embeddings_history = history[-10:] # Keep most recent 10
# Recalculate confidence score
speaker.confidence_score = calculate_confidence(speaker)
# Commit changes
db.session.commit()
return similarity_to_avg
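
# Worked example of the weighted update: with current average a and new sample
# n, the stored profile becomes 0.7*a + 0.3*n. After three samples e1, e2, e3
# (in that order) the average is therefore 0.49*e1 + 0.21*e2 + 0.30*e3 — older
# samples decay geometrically instead of being dropped outright.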
def calculate_confidence(speaker):
"""
Calculate confidence score based on embedding consistency.
Confidence is based on:
- Number of samples (more is better)
- Consistency of embeddings (high similarity scores = high confidence)
Args:
speaker: Speaker model instance with embeddings_history
Returns:
float: Confidence score (0-1)
"""
if speaker.embedding_count is None or speaker.embedding_count < 1:
return 0.0
if speaker.embedding_count == 1:
return 0.5 # Medium confidence with single sample
# Get recent similarity scores from history
history = speaker.embeddings_history or []
if len(history) < 2:
return 0.5
# Use last 5 samples
recent_history = history[-5:]
similarities = [h.get('similarity', 0) / 100.0 for h in recent_history]
# Average similarity to the profile
avg_similarity = sum(similarities) / len(similarities)
# Penalize if we have very few samples
sample_factor = min(1.0, speaker.embedding_count / 5.0)
# Confidence = average similarity × sample factor
confidence = avg_similarity * sample_factor
return min(1.0, max(0.0, confidence))
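# Worked example: embedding_count=3 with history similarities [80.0, 90.0,
# 85.0] gives avg_similarity 0.85 and sample_factor min(1.0, 3/5) = 0.6, so
# confidence = 0.85 * 0.6 = 0.51 ("low" per _get_confidence_level below).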
def get_speaker_voice_profile_summary(speaker):
"""
Get a human-readable summary of a speaker's voice profile.
Args:
speaker: Speaker model instance
Returns:
dict: Profile summary with statistics and status
"""
if not speaker.average_embedding:
return {
'has_profile': False,
'message': 'No voice profile yet'
}
return {
'has_profile': True,
'embedding_count': speaker.embedding_count or 0,
'confidence_score': speaker.confidence_score or 0.0,
'confidence_level': _get_confidence_level(speaker.confidence_score),
'last_updated': speaker.embeddings_history[-1]['timestamp'] if speaker.embeddings_history else None,
'recordings': len(speaker.embeddings_history or [])
}
def _get_confidence_level(score):
"""
Convert numeric confidence score to human-readable level.
Args:
score: float (0-1)
Returns:
str: 'low', 'medium', or 'high'
"""
if score is None or score < 0.6:
return 'low'
elif score < 0.8:
return 'medium'
else:
return 'high'
# Threshold mapping for auto-labelling
AUTO_LABEL_THRESHOLDS = {
'low': 0.3, # Aggressive, may have more false positives
'medium': 0.6, # Default, balanced approach
'high': 0.8 # Only auto-label well-established speakers
}
# Base similarity threshold for finding matches (70%)
BASE_SIMILARITY_THRESHOLD = 0.70
# Ambiguity threshold: if top 2 matches are within 5% similarity, skip
AMBIGUITY_MARGIN = 0.05
def apply_auto_speaker_labels(recording, user):
"""
Automatically label speakers in a recording based on voice profile matching.
This function matches speaker embeddings from the recording against the user's
saved speaker profiles and returns a mapping of generic labels to speaker names.
Args:
recording: Recording model instance with speaker_embeddings
user: User model instance with auto_speaker_labelling settings
Returns:
dict: Mapping of {SPEAKER_XX: speaker_name} for matched speakers,
or empty dict if auto-labelling is disabled or no matches found
"""
# Check if user has auto-labelling enabled
if not user.auto_speaker_labelling:
return {}
# Check if recording has speaker embeddings
if not recording.speaker_embeddings:
return {}
# Get the user's threshold setting
threshold_setting = user.auto_speaker_labelling_threshold or 'medium'
confidence_threshold = AUTO_LABEL_THRESHOLDS.get(threshold_setting, AUTO_LABEL_THRESHOLDS['medium'])
speaker_map = {}
embeddings = recording.speaker_embeddings
for speaker_label, embedding_data in embeddings.items():
# embedding_data should be a list of floats (256 dimensions)
if not embedding_data or not isinstance(embedding_data, list):
continue
# Find matching speakers with base similarity threshold
matches = find_matching_speakers(
target_embedding=embedding_data,
user_id=user.id,
threshold=BASE_SIMILARITY_THRESHOLD
)
if not matches:
continue
# Check if the best match exceeds the user's confidence threshold
best_match = matches[0]
best_similarity = best_match['similarity'] / 100.0 # Convert from percentage
if best_similarity < confidence_threshold:
continue
# Check for ambiguity: if top 2 matches are within 5% similarity, skip
if len(matches) >= 2:
second_similarity = matches[1]['similarity'] / 100.0
if (best_similarity - second_similarity) <= AMBIGUITY_MARGIN:
# Ambiguous - top 2 matches too close
continue
# We have a clear winner - add to speaker map
speaker_map[speaker_label] = best_match['name']
return speaker_map
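# Worked example of the ambiguity guard (hypothetical scores): a best match at
# 88% with a runner-up at 85% differs by 0.03 <= AMBIGUITY_MARGIN, so the
# label is skipped; 88% vs 80% (0.08 apart) is accepted, provided 0.88 also
# meets the user's confidence threshold (0.3/0.6/0.8 for low/medium/high).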
def apply_speaker_names_to_transcription(recording, speaker_map):
"""
Apply speaker name mappings to a recording's transcription.
This function updates the transcription JSON by replacing generic speaker
labels (SPEAKER_00, SPEAKER_01, etc.) with actual speaker names, and
updates the recording's participants list.
Args:
recording: Recording model instance with transcription
speaker_map: Dict mapping {SPEAKER_XX: speaker_name}
Returns:
bool: True if changes were made, False otherwise
"""
if not speaker_map or not recording.transcription:
logger.warning(f"Auto-label: No speaker_map or transcription (map={bool(speaker_map)}, trans={bool(recording.transcription)})")
return False
try:
# Parse transcription as JSON array: [{speaker, sentence, start_time, end_time}, ...]
segments = json.loads(recording.transcription)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Auto-label: Failed to parse transcription as JSON: {e}")
return False
if not isinstance(segments, list) or not segments:
logger.warning(f"Auto-label: Transcription not in expected array format")
return False
# Track which speakers were renamed
renamed_speakers = set()
# Update speaker labels in segments
for segment in segments:
if 'speaker' in segment and segment['speaker'] in speaker_map:
segment['speaker'] = speaker_map[segment['speaker']]
renamed_speakers.add(segment['speaker'])
if not renamed_speakers:
logger.warning(f"Auto-label: No speakers matched in segments")
return False
logger.info(f"Auto-label: Applied names to {len(renamed_speakers)} speakers: {renamed_speakers}")
# Update participants field
all_speakers = set(s.get('speaker') for s in segments if 'speaker' in s)
if all_speakers:
recording.participants = ', '.join(sorted(all_speakers))
# Save updated transcription
recording.transcription = json.dumps(segments)
db.session.commit()
return True
def update_speaker_profiles_from_recording(recording, speaker_map, user):
"""
Update speaker voice profiles with new embeddings from a recording.
For each successfully matched speaker, this function updates their
average embedding and increments their usage count.
Args:
recording: Recording model instance with speaker_embeddings
speaker_map: Dict mapping {SPEAKER_XX: speaker_name} that was applied
user: User model instance
Returns:
int: Number of speaker profiles updated
"""
if not speaker_map or not recording.speaker_embeddings:
return 0
updated_count = 0
embeddings = recording.speaker_embeddings
for speaker_label, speaker_name in speaker_map.items():
if speaker_label not in embeddings:
continue
embedding_data = embeddings[speaker_label]
if not embedding_data or not isinstance(embedding_data, list):
continue
# Find the speaker profile
speaker = Speaker.query.filter_by(
user_id=user.id,
name=speaker_name
).first()
if not speaker:
continue
try:
# Update the speaker's embedding with the new sample
update_speaker_embedding(speaker, embedding_data, recording.id)
# Update usage tracking
speaker.use_count = (speaker.use_count or 0) + 1
speaker.last_used = datetime.utcnow()
updated_count += 1
except Exception:
# Skip if embedding update fails
continue
if updated_count > 0:
db.session.commit()
return updated_count

View File

@@ -0,0 +1,228 @@
"""
Shared speaker identification service.
Provides LLM-based speaker identification from transcript context,
used by both the web UI (recordings.py) and REST API (api_v1.py).
"""
import os
import re
import json
from flask import current_app
def identify_speakers_from_transcript(transcription_data, user_id):
"""
Identify speakers in a transcription using an LLM.
Args:
transcription_data: List of transcript segments (already parsed JSON).
user_id: Current user's ID (for token tracking).
Returns:
dict mapping original speaker labels to identified names.
Values are empty string "" for unidentified speakers.
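    Example return value (illustrative):
        {"SPEAKER_00": "Marie Lavoie", "SPEAKER_01": ""}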
Raises:
ValueError: If LLM API key is not configured.
Exception: On LLM call failure.
"""
from src.services.llm import call_llm_completion
from src.utils import safe_json_loads
from src.models import SystemSetting
# Extract unique speakers in order of appearance
seen_speakers = set()
unique_speakers = []
for segment in transcription_data:
speaker = segment.get('speaker')
if speaker and speaker not in seen_speakers:
seen_speakers.add(speaker)
unique_speakers.append(speaker)
if not unique_speakers:
return {}
# Normalize all labels to SPEAKER_XX format for the LLM
speaker_to_label = {}
for idx, speaker in enumerate(unique_speakers):
speaker_to_label[speaker] = f'SPEAKER_{str(idx).zfill(2)}'
# Create temporary transcript with normalized labels
formatted_lines = []
for segment in transcription_data:
original_speaker = segment.get('speaker')
label = speaker_to_label.get(original_speaker, 'Unknown Speaker')
sentence = segment.get('sentence', '')
formatted_lines.append(f"[{label}]: {sentence}")
formatted_transcription = "\n".join(formatted_lines)
speaker_labels = list(speaker_to_label.values())
current_app.logger.info(f"[Auto-Identify] Formatted transcript (first 500 chars): {formatted_transcription[:500]}")
current_app.logger.info(f"[Auto-Identify] Speaker labels: {speaker_labels}")
# Apply configurable transcript length limit
transcript_limit = SystemSetting.get_setting('transcript_length_limit', 30000)
if transcript_limit == -1:
transcript_text = formatted_transcription
else:
transcript_text = formatted_transcription[:transcript_limit]
prompt = f"""Analyse cette transcription de conversation et identifie les noms des locuteurs à partir du contexte et du contenu de leurs dialogues.
Les locuteurs à identifier sont : {', '.join(speaker_labels)}
Indices à chercher :
- Noms mentionnés par d'autres locuteurs quand ils s'adressent à quelqu'un
- Présentations ou références à son propre nom
- Indices contextuels sur les rôles, relations ou postes
- Toute mention directe de noms dans le dialogue
Transcription complète :
{transcript_text}
À partir de cette conversation, identifie les noms les plus probables pour chaque locuteur. Porte une attention particulière à la façon dont les locuteurs s'adressent les uns aux autres.
Réponds avec un seul objet JSON où les clés sont les étiquettes de locuteurs (ex. "SPEAKER_01") et les valeurs sont les noms complets identifiés. Si un nom ne peut pas être déterminé, utilise une chaîne vide "".
Exemple :
{{
"SPEAKER_01": "Marie Lavoie",
"SPEAKER_03": "Jean Tremblay",
"SPEAKER_05": ""
}}
Réponse JSON :
"""
current_app.logger.info("[Auto-Identify] Calling LLM")
    use_schema = os.environ.get('AUTO_IDENTIFY_RESPONSE_SCHEMA', '').strip().lower() in ('1', 'true', 'yes')
system_msg = (
"You are an expert in analyzing conversation transcripts to identify speakers "
"based on contextual clues in the dialogue. Analyze the conversation carefully "
"to find names mentioned when speakers address each other or introduce themselves. "
"Your response must be a single, valid JSON object containing only the requested "
"speaker identifications."
)
response_content = None
if use_schema:
# Build JSON schema response format with constrained keys
schema_properties = {label: {"type": "string"} for label in speaker_labels}
schema_response_format = {
"type": "json_schema",
"json_schema": {
"name": "speaker_identification",
"strict": True,
"schema": {
"type": "object",
"properties": schema_properties,
"required": speaker_labels,
"additionalProperties": False
}
}
}
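        # For two labels this yields (illustrative):
        # {"type": "json_schema", "json_schema": {"name": "speaker_identification", "strict": True,
        #   "schema": {"type": "object",
        #     "properties": {"SPEAKER_00": {"type": "string"}, "SPEAKER_01": {"type": "string"}},
        #     "required": ["SPEAKER_00", "SPEAKER_01"], "additionalProperties": False}}}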
schema_prompt = prompt + f"\n\nIMPORTANT: Your JSON response must contain exactly these keys: {', '.join(speaker_labels)}"
try:
current_app.logger.info("[Auto-Identify] Trying json_schema response format")
completion = call_llm_completion(
messages=[
{"role": "system", "content": system_msg},
{"role": "user", "content": schema_prompt}
],
temperature=0.2,
response_format=schema_response_format,
user_id=user_id,
operation_type='speaker_identification'
)
response_content = completion.choices[0].message.content
current_app.logger.info(f"[Auto-Identify] LLM Raw Response (schema mode): {response_content}")
except Exception as schema_err:
current_app.logger.warning(f"[Auto-Identify] json_schema mode failed, falling back to json_object: {schema_err}")
response_content = None
if response_content is None:
completion = call_llm_completion(
messages=[
{"role": "system", "content": system_msg},
{"role": "user", "content": prompt}
],
temperature=0.2,
user_id=user_id,
operation_type='speaker_identification'
)
response_content = completion.choices[0].message.content
current_app.logger.info(f"[Auto-Identify] LLM Raw Response: {response_content}")
identified_map = safe_json_loads(response_content, {})
current_app.logger.info(f"[Auto-Identify] Parsed identified_map: {identified_map}")
# --- Sanitize identified_map ---
identified_map = _sanitize_identified_map(identified_map, speaker_labels)
current_app.logger.info(f"[Auto-Identify] Sanitized identified_map: {identified_map}")
# Map back to original speaker labels
final_speaker_map = {}
for original_speaker, temp_label in speaker_to_label.items():
if temp_label in identified_map:
final_speaker_map[original_speaker] = identified_map[temp_label]
current_app.logger.info(f"[Auto-Identify] Final speaker_map: {final_speaker_map}")
return final_speaker_map
def _sanitize_identified_map(identified_map, speaker_labels):
"""
Clean up LLM output: handle inverted maps, strip commentary,
clear placeholders, etc.
"""
speaker_label_re = re.compile(r'^SPEAKER_\d{2}$')
# Detect inverted map ({name: "SPEAKER_XX"}) and flip it
if identified_map and all(
speaker_label_re.match(str(v)) for v in identified_map.values() if v
) and not any(speaker_label_re.match(str(k)) for k in identified_map.keys()):
current_app.logger.warning("[Auto-Identify] Detected inverted map, flipping keys/values")
identified_map = {v: k for k, v in identified_map.items() if v}
sanitized = {}
for speaker_label, identified_name in identified_map.items():
# Skip entries whose key isn't a valid SPEAKER_XX label
if not speaker_label_re.match(str(speaker_label)):
continue
if not identified_name or not isinstance(identified_name, str):
sanitized[speaker_label] = ""
continue
name = identified_name.strip()
# Clear generic placeholders
if name.lower() in ["unknown", "n/a", "not available", "unclear", "unidentified", ""]:
sanitized[speaker_label] = ""
continue
# Clear label-to-label entries (e.g. "SPEAKER_01": "SPEAKER_02")
if speaker_label_re.match(name):
sanitized[speaker_label] = ""
continue
# Strip parenthetical content: "John (the host)" -> "John"
name = re.sub(r'\s*\([^)]*\)', '', name).strip()
# Take first name segment before comma, semicolon, or slash
name = re.split(r'[,;/]', name)[0].strip()
# Collapse whitespace
name = re.sub(r'\s+', ' ', name)
# Final check: if result still matches SPEAKER_XX, clear it
if speaker_label_re.match(name) or not name:
sanitized[speaker_label] = ""
continue
sanitized[speaker_label] = name
return sanitized
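# Illustrative sanitization (hypothetical LLM output):
#   {"SPEAKER_00": "Marie (the host)", "SPEAKER_01": "unknown", "SPEAKER_02": "SPEAKER_00"}
#   -> {"SPEAKER_00": "Marie", "SPEAKER_01": "", "SPEAKER_02": ""}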

View File

@@ -0,0 +1,226 @@
"""
Speaker Merge Service.
This service handles merging multiple speaker profiles into one.
Useful when users accidentally create duplicate speakers for the same person.
When speakers are merged:
- Voice embeddings are combined using weighted average
- All snippets are transferred to the target speaker
- Usage statistics are combined
- Source speakers are deleted
- Confidence score is recalculated
"""
import numpy as np
from src.database import db
from src.models import Speaker, SpeakerSnippet
from src.services.speaker_embedding_matcher import (
serialize_embedding,
deserialize_embedding,
calculate_confidence
)
def merge_speakers(target_id, source_ids, user_id):
"""
Merge multiple speaker profiles into one target speaker.
All embeddings, snippets, and usage data from source speakers are
combined into the target speaker. Source speakers are then deleted.
Args:
target_id: ID of the speaker to keep (receives all merged data)
source_ids: List of speaker IDs to merge into target
user_id: ID of the user (for security check)
Returns:
Speaker: The updated target speaker
Raises:
ValueError: If speakers don't exist or don't belong to user
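    Example (illustrative):
        # fold duplicate profiles 7 and 9 into profile 3 owned by user 1
        merged = merge_speakers(target_id=3, source_ids=[7, 9], user_id=1)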
"""
# Validate target speaker
target = Speaker.query.filter_by(id=target_id, user_id=user_id).first()
if not target:
raise ValueError(f"Target speaker {target_id} not found or doesn't belong to user")
# Validate source speakers
sources = Speaker.query.filter(
Speaker.id.in_(source_ids),
Speaker.user_id == user_id
).all()
if len(sources) == 0:
raise ValueError("No valid source speakers found")
if len(sources) != len(source_ids):
raise ValueError("Some source speakers don't exist or don't belong to user")
# Can't merge a speaker with itself
if target_id in source_ids:
raise ValueError("Cannot merge a speaker with itself")
# Combine embeddings
_combine_embeddings(target, sources)
# Transfer snippets
for source in sources:
SpeakerSnippet.query.filter_by(speaker_id=source.id).update(
{'speaker_id': target_id}
)
# Combine usage statistics
for source in sources:
target.use_count += source.use_count
# Update last_used to most recent
if source.last_used and (not target.last_used or source.last_used > target.last_used):
target.last_used = source.last_used
# Combine embedding histories
if source.embeddings_history:
target_history = target.embeddings_history or []
source_history = source.embeddings_history or []
combined_history = target_history + source_history
# Sort by timestamp (most recent last) and keep last 10
try:
combined_history.sort(key=lambda x: x.get('timestamp', ''))
target.embeddings_history = combined_history[-10:]
            except Exception:
# If sorting fails, just concatenate and truncate
target.embeddings_history = (target_history + source_history)[-10:]
# Recalculate confidence score
target.confidence_score = calculate_confidence(target)
# Delete source speakers
for source in sources:
db.session.delete(source)
# Commit all changes
db.session.commit()
return target
def _combine_embeddings(target, sources):
"""
Combine embeddings from multiple speakers using weighted average.
Weight is based on embedding_count (more samples = more weight).
Args:
target: Target Speaker instance
sources: List of source Speaker instances
"""
all_embeddings = []
all_counts = []
# Add target's embedding if it exists
if target.average_embedding:
all_embeddings.append(deserialize_embedding(target.average_embedding))
all_counts.append(target.embedding_count or 1)
# Add all source embeddings
for source in sources:
if source.average_embedding:
all_embeddings.append(deserialize_embedding(source.average_embedding))
all_counts.append(source.embedding_count or 1)
if not all_embeddings:
# No embeddings to combine
return
# Calculate weighted average
total_count = sum(all_counts)
weights = [c / total_count for c in all_counts]
combined_emb = np.average(all_embeddings, axis=0, weights=weights)
# Update target
target.average_embedding = serialize_embedding(combined_emb)
target.embedding_count = total_count
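# Worked example for _combine_embeddings (hypothetical counts): a target built from
# 3 samples merged with a source built from 1 sample gets weights [0.75, 0.25], so the
# merged vector is 0.75 * target + 0.25 * source and embedding_count becomes 4.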
def preview_merge(target_id, source_ids, user_id):
"""
Preview what a merge would look like without executing it.
Args:
target_id: ID of the target speaker
source_ids: List of source speaker IDs
user_id: ID of the user
Returns:
dict: Preview of the merge results
{
'target_name': '...',
'source_names': [...],
'combined_use_count': 123,
'combined_embedding_count': 45,
'total_snippets': 67
}
"""
# Validate speakers
target = Speaker.query.filter_by(id=target_id, user_id=user_id).first()
if not target:
raise ValueError("Target speaker not found")
sources = Speaker.query.filter(
Speaker.id.in_(source_ids),
Speaker.user_id == user_id
).all()
if len(sources) == 0:
raise ValueError("No valid source speakers found")
# Calculate combined statistics
combined_use_count = target.use_count
combined_embedding_count = target.embedding_count or 0
total_snippets = SpeakerSnippet.query.filter_by(speaker_id=target_id).count()
source_names = []
for source in sources:
combined_use_count += source.use_count
combined_embedding_count += (source.embedding_count or 0)
total_snippets += SpeakerSnippet.query.filter_by(speaker_id=source.id).count()
source_names.append(source.name)
return {
'target_name': target.name,
'source_names': source_names,
'combined_use_count': combined_use_count,
'combined_embedding_count': combined_embedding_count,
'total_snippets': total_snippets,
'has_embeddings': target.average_embedding is not None or any(s.average_embedding for s in sources)
}
def can_merge_speakers(speaker_ids, user_id):
"""
Check if speakers can be merged (all belong to same user, no duplicates).
Args:
speaker_ids: List of speaker IDs
user_id: ID of the user
Returns:
tuple: (bool, str) - (can_merge, error_message)
"""
if len(speaker_ids) < 2:
return False, "Need at least 2 speakers to merge"
if len(speaker_ids) != len(set(speaker_ids)):
return False, "Duplicate speaker IDs provided"
speakers = Speaker.query.filter(
Speaker.id.in_(speaker_ids),
Speaker.user_id == user_id
).all()
if len(speakers) != len(speaker_ids):
return False, "Some speakers don't exist or don't belong to user"
return True, ""

View File

@@ -0,0 +1,371 @@
"""
Speaker Snippets Service.
This service handles the extraction and management of representative speech snippets
from recordings. Snippets provide context when viewing speaker profiles and help
users verify speaker identifications.
Key functions:
- Extract snippets when speakers are identified in recordings
- Retrieve snippets for display in speaker profiles
- Clean up old snippets to prevent database bloat
"""
import json
from src.database import db
from src.models import Speaker, SpeakerSnippet, Recording
MAX_SNIPPETS_PER_SPEAKER = 7
MAX_SNIPPETS_PER_RECORDING = 2
def create_speaker_snippets(recording_id, speaker_map):
"""
Extract and store representative snippets for each identified speaker.
This function is called after a user saves speaker identifications in a recording.
It extracts up to MAX_SNIPPETS_PER_RECORDING quotes per speaker from this recording,
and enforces a global cap of MAX_SNIPPETS_PER_SPEAKER by evicting the oldest.
Args:
recording_id: ID of the recording
speaker_map: Dict mapping SPEAKER_XX to speaker info
{'SPEAKER_00': {'name': 'John Doe', 'isMe': False}, ...}
Returns:
int: Number of snippets created
"""
recording = Recording.query.get(recording_id)
if not recording or not recording.transcription:
return 0
try:
transcript = json.loads(recording.transcription)
except (json.JSONDecodeError, TypeError):
return 0
# Build a reverse map: assigned name -> speaker_info
    # After the transcript is saved, segment['speaker'] contains the real name,
# not the original SPEAKER_XX label. We need to match by name too.
name_to_info = {}
for label, info in speaker_map.items():
name = info.get('name', '').strip()
if name and not name.startswith('SPEAKER_'):
name_to_info[name] = info
# Collect candidates per speaker: (speaker_obj, segment_idx, text, timestamp)
candidates = {} # speaker_id -> list of (segment_idx, text, timestamp)
for segment_idx, segment in enumerate(transcript):
speaker_field = segment.get('speaker')
if not speaker_field:
continue
# Try matching by original label first, then by assigned name
if speaker_field in speaker_map:
speaker_info = speaker_map[speaker_field]
speaker_name = speaker_info.get('name')
elif speaker_field in name_to_info:
speaker_name = speaker_field
else:
continue
if not speaker_name or speaker_name.startswith('SPEAKER_'):
continue
# Find the speaker in database
speaker = Speaker.query.filter_by(
user_id=recording.user_id,
name=speaker_name
).first()
if not speaker:
continue
text = segment.get('sentence', '').strip()
if len(text) < 10:
continue
if speaker.id not in candidates:
candidates[speaker.id] = []
candidates[speaker.id].append((segment_idx, text[:200], segment.get('start_time')))
# Delete existing snippets for this recording (re-save replaces them)
SpeakerSnippet.query.filter_by(recording_id=recording_id).delete()
snippets_created = 0
for speaker_id, segs in candidates.items():
# Pick up to MAX_SNIPPETS_PER_RECORDING spread across the transcript
if len(segs) <= MAX_SNIPPETS_PER_RECORDING:
chosen = segs
else:
# Evenly sample from the segments
step = len(segs) / MAX_SNIPPETS_PER_RECORDING
chosen = [segs[int(i * step)] for i in range(MAX_SNIPPETS_PER_RECORDING)]
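            # e.g. (illustrative) 5 candidates with a cap of 2 -> step 2.5,
            # picking indices int(0 * 2.5) = 0 and int(1 * 2.5) = 2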
for segment_idx, text_snippet, timestamp in chosen:
# Evict oldest if at global cap
global_count = SpeakerSnippet.query.filter_by(speaker_id=speaker_id).count()
if global_count >= MAX_SNIPPETS_PER_SPEAKER:
oldest = SpeakerSnippet.query.filter_by(speaker_id=speaker_id)\
.order_by(SpeakerSnippet.created_at.asc()).first()
if oldest:
db.session.delete(oldest)
db.session.flush()
snippet = SpeakerSnippet(
speaker_id=speaker_id,
recording_id=recording_id,
segment_index=segment_idx,
text_snippet=text_snippet,
timestamp=timestamp,
)
db.session.add(snippet)
snippets_created += 1
# Flush after each speaker batch to keep counts accurate
db.session.flush()
if snippets_created > 0:
db.session.commit()
return snippets_created
def _generate_dynamic_snippets(speaker_id, limit=3):
"""
Dynamically generate audio snippets from a speaker's recent recordings.
This function finds short audio segments (3-4 seconds) from recent recordings
where the speaker appears. These can be played back to verify speaker identity.
Args:
speaker_id: ID of the speaker
limit: Maximum number of snippets to return (default 3)
Returns:
list: List of snippet dictionaries with audio segment information
[{'recording_id': 123, 'start_time': 45.2, 'duration': 3.5, ...}, ...]
"""
# Get the speaker
speaker = Speaker.query.get(speaker_id)
if not speaker:
return []
# Find recordings that have this speaker's name in transcription
# We'll look at the last 10 recordings and extract snippets from them
recordings = Recording.query.filter_by(user_id=speaker.user_id)\
.filter(Recording.transcription.isnot(None))\
.filter(Recording.transcription != '')\
.filter(Recording.audio_deleted_at.is_(None))\
.order_by(Recording.created_at.desc())\
.limit(10).all()
snippets = []
for recording in recordings:
if len(snippets) >= limit:
break
try:
# Parse transcription JSON
transcript = json.loads(recording.transcription)
if not isinstance(transcript, list):
continue
# Find segments where this speaker appears
speaker_segments = []
for idx, segment in enumerate(transcript):
# Check if segment has speaker identification matching our speaker's name
speaker_label = segment.get('speaker')
# In identified transcripts, the speaker field contains the actual name
if speaker_label != speaker.name:
continue
start_time = segment.get('start_time')
end_time = segment.get('end_time')
if start_time is None or end_time is None:
continue
duration = end_time - start_time
# Skip very short segments (less than 2 seconds)
if duration < 2.0:
continue
speaker_segments.append({
'index': idx,
'start_time': start_time,
'end_time': end_time,
'duration': duration,
'text': segment.get('sentence', '').strip()[:100] # Preview text
})
if not speaker_segments:
continue
# Take snippets from middle portions (skip first and last 10%)
total_segments = len(speaker_segments)
if total_segments > 4:
# Skip first and last 10%
start_idx = max(1, int(total_segments * 0.1))
end_idx = min(total_segments - 1, int(total_segments * 0.9))
middle_segments = speaker_segments[start_idx:end_idx]
else:
middle_segments = speaker_segments
# Take 1 snippet per recording from the middle
if middle_segments:
# Pick a segment from the middle
middle_idx = len(middle_segments) // 2
segment = middle_segments[middle_idx]
# Limit audio snippet to 3-4 seconds
snippet_duration = min(4.0, segment['duration'])
snippets.append({
'id': None, # Dynamic snippet, no database ID
'speaker_id': speaker_id,
'recording_id': recording.id,
'start_time': segment['start_time'],
'duration': snippet_duration,
'text': segment['text'], # Preview text for context
'recording_title': recording.title or 'Untitled Recording',
'created_at': recording.created_at.isoformat() if recording.created_at else None
})
except (json.JSONDecodeError, TypeError, KeyError) as e:
# Skip recordings with invalid transcription format
continue
return snippets
def get_speaker_snippets(speaker_id, limit=3):
"""
Get recent audio snippets for a speaker.
Returns short audio segments (3-4 seconds) from recent recordings where this
speaker appears. These audio snippets can be played to verify speaker identity.
Args:
speaker_id: ID of the speaker
limit: Maximum number of snippets to return (default 3)
Returns:
list: List of snippet dictionaries with audio segment information
[{'recording_id': 123, 'start_time': 45.2, 'duration': 3.5, ...}, ...]
"""
# Always dynamically generate audio snippets from recent recordings
return _generate_dynamic_snippets(speaker_id, limit)
def get_snippets_by_recording(recording_id, speaker_id):
"""
Get all snippets for a specific speaker in a specific recording.
Args:
recording_id: ID of the recording
speaker_id: ID of the speaker
Returns:
list: List of snippet dictionaries
"""
snippets = SpeakerSnippet.query.filter_by(
recording_id=recording_id,
speaker_id=speaker_id
).order_by(SpeakerSnippet.segment_index).all()
return [snippet.to_dict() for snippet in snippets]
def cleanup_old_snippets(speaker_id, keep=10):
"""
Clean up old snippets for a speaker, keeping only the most recent ones.
Args:
speaker_id: ID of the speaker
keep: Number of snippets to keep (default 10)
Returns:
int: Number of snippets deleted
"""
# Get all snippets for this speaker, ordered by creation date
all_snippets = SpeakerSnippet.query.filter_by(speaker_id=speaker_id)\
.order_by(SpeakerSnippet.created_at.desc()).all()
if len(all_snippets) <= keep:
return 0
# Delete old snippets beyond the keep limit
snippets_to_delete = all_snippets[keep:]
deleted_count = 0
for snippet in snippets_to_delete:
db.session.delete(snippet)
deleted_count += 1
if deleted_count > 0:
db.session.commit()
return deleted_count
def delete_snippets_for_recording(recording_id):
"""
Delete all snippets associated with a recording.
This is typically called when a recording is deleted or reprocessed.
Args:
recording_id: ID of the recording
Returns:
int: Number of snippets deleted
"""
deleted_count = SpeakerSnippet.query.filter_by(recording_id=recording_id).delete()
db.session.commit()
return deleted_count
def get_speaker_recordings_with_snippets(speaker_id):
"""
Get a list of recordings that have snippets for this speaker.
Args:
speaker_id: ID of the speaker
Returns:
list: List of recording dictionaries with snippet counts
        [{'id': 123, 'title': '...', 'snippet_count': 3, 'created_at': '...'}, ...]
"""
# Get distinct recordings with snippet counts
from sqlalchemy import func
recordings_with_counts = db.session.query(
Recording.id,
Recording.title,
Recording.created_at,
func.count(SpeakerSnippet.id).label('snippet_count')
).join(
SpeakerSnippet,
Recording.id == SpeakerSnippet.recording_id
).filter(
SpeakerSnippet.speaker_id == speaker_id
).group_by(
Recording.id
).order_by(
Recording.created_at.desc()
).all()
return [{
'id': r.id,
'title': r.title,
'snippet_count': r.snippet_count,
'created_at': r.created_at.isoformat() if r.created_at else None
} for r in recordings_with_counts]

View File

@@ -0,0 +1,270 @@
"""
Token usage tracking service for monitoring LLM API consumption and budget enforcement.
"""
import logging
from datetime import date, datetime, timedelta
from typing import Tuple, Optional, Dict, List
from sqlalchemy import func, extract
from src.database import db
from src.models.token_usage import TokenUsage
from src.models.user import User
logger = logging.getLogger(__name__)
class TokenTracker:
"""Service for recording and checking token usage."""
OPERATION_TYPES = [
'summarization',
'chat',
'title_generation',
'event_extraction',
'query_routing',
'query_enrichment'
]
def record_usage(
self,
user_id: int,
operation_type: str,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model_name: str = None,
cost: float = None
):
"""
Record token usage - upserts into daily aggregate.
Args:
user_id: User ID who made the request
operation_type: Type of operation (summarization, chat, etc.)
prompt_tokens: Number of input tokens
completion_tokens: Number of output tokens
total_tokens: Total tokens (prompt + completion)
model_name: Name of the model used
cost: API cost if available (e.g., from OpenRouter)
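        Example (illustrative):
            token_tracker.record_usage(user_id=42, operation_type='chat',
                                       prompt_tokens=900, completion_tokens=100,
                                       total_tokens=1000, model_name='gpt-4o-mini')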
"""
try:
today = date.today()
# Find or create today's record for this user/operation
usage = TokenUsage.query.filter_by(
user_id=user_id,
date=today,
operation_type=operation_type
).first()
if usage:
# Update existing record
usage.prompt_tokens += prompt_tokens
usage.completion_tokens += completion_tokens
usage.total_tokens += total_tokens
usage.request_count += 1
if cost:
usage.cost += cost
if model_name:
usage.model_name = model_name # Update to latest model used
else:
# Create new record
usage = TokenUsage(
user_id=user_id,
date=today,
operation_type=operation_type,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
request_count=1,
model_name=model_name,
cost=cost or 0.0
)
db.session.add(usage)
db.session.commit()
logger.debug(f"Recorded {total_tokens} tokens for user {user_id}, operation {operation_type}")
return usage
except Exception as e:
logger.error(f"Failed to record token usage: {e}")
db.session.rollback()
return None
def get_monthly_usage(self, user_id: int, year: int = None, month: int = None) -> int:
"""Get total tokens used by a user in a given month."""
if year is None:
year = date.today().year
if month is None:
month = date.today().month
result = db.session.query(func.sum(TokenUsage.total_tokens)).filter(
TokenUsage.user_id == user_id,
extract('year', TokenUsage.date) == year,
extract('month', TokenUsage.date) == month
).scalar()
return result or 0
def get_monthly_cost(self, user_id: int, year: int = None, month: int = None) -> float:
"""Get total cost for a user in a given month."""
if year is None:
year = date.today().year
if month is None:
month = date.today().month
result = db.session.query(func.sum(TokenUsage.cost)).filter(
TokenUsage.user_id == user_id,
extract('year', TokenUsage.date) == year,
extract('month', TokenUsage.date) == month
).scalar()
return result or 0.0
def check_budget(self, user_id: int) -> Tuple[bool, float, Optional[str]]:
"""
Check if user is within budget.
Returns:
(can_proceed, usage_percentage, message)
- can_proceed: False if hard cap (100%) reached
- usage_percentage: 0-100+
- message: Warning/error message if applicable
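            Example (illustrative): with a 100,000-token budget and 85,000 tokens
            already used this month, returns
            (True, 85.0, "Warning: 85.0% of monthly token budget used.")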
"""
try:
user = db.session.get(User, user_id)
if not user or not user.monthly_token_budget:
return (True, 0, None) # No budget = unlimited
current_usage = self.get_monthly_usage(user_id)
budget = user.monthly_token_budget
percentage = (current_usage / budget) * 100
if percentage >= 100:
return (False, percentage,
f"Monthly token budget exceeded ({percentage:.1f}%). Contact admin for more tokens.")
elif percentage >= 80:
return (True, percentage,
f"Warning: {percentage:.1f}% of monthly token budget used.")
else:
return (True, percentage, None)
except Exception as e:
logger.error(f"Failed to check budget for user {user_id}: {e}")
# Fail open - allow the request if we can't check
return (True, 0, None)
def get_daily_stats(self, days: int = 30, user_id: int = None) -> List[Dict]:
"""Get daily token usage for charts."""
start_date = date.today() - timedelta(days=days - 1)
query = db.session.query(
TokenUsage.date,
TokenUsage.operation_type,
func.sum(TokenUsage.total_tokens).label('tokens'),
func.sum(TokenUsage.cost).label('cost')
).filter(TokenUsage.date >= start_date)
if user_id:
query = query.filter(TokenUsage.user_id == user_id)
results = query.group_by(TokenUsage.date, TokenUsage.operation_type).all()
# Organize by date
stats = {}
for r in results:
date_str = r.date.isoformat()
if date_str not in stats:
stats[date_str] = {'date': date_str, 'total': 0, 'cost': 0.0, 'by_operation': {}}
stats[date_str]['total'] += r.tokens or 0
stats[date_str]['cost'] += r.cost or 0
stats[date_str]['by_operation'][r.operation_type] = r.tokens or 0
# Fill in missing dates with zeros
all_dates = []
current = start_date
while current <= date.today():
date_str = current.isoformat()
if date_str not in stats:
stats[date_str] = {'date': date_str, 'total': 0, 'cost': 0.0, 'by_operation': {}}
all_dates.append(date_str)
current += timedelta(days=1)
return [stats[d] for d in sorted(all_dates)]
def get_monthly_stats(self, months: int = 12) -> List[Dict]:
"""Get monthly token usage for charts."""
results = db.session.query(
extract('year', TokenUsage.date).label('year'),
extract('month', TokenUsage.date).label('month'),
func.sum(TokenUsage.total_tokens).label('tokens'),
func.sum(TokenUsage.cost).label('cost')
).group_by('year', 'month').order_by('year', 'month').all()
# Get last N months
monthly_data = [
{
'year': int(r.year),
'month': int(r.month),
'tokens': r.tokens or 0,
'cost': r.cost or 0
}
for r in results
]
return monthly_data[-months:] if len(monthly_data) > months else monthly_data
def get_user_stats(self) -> List[Dict]:
"""Get per-user token usage breakdown for current month."""
today = date.today()
results = db.session.query(
User.id,
User.username,
User.monthly_token_budget,
func.sum(TokenUsage.total_tokens).label('usage'),
func.sum(TokenUsage.cost).label('cost')
).outerjoin(
TokenUsage,
(User.id == TokenUsage.user_id) &
(extract('year', TokenUsage.date) == today.year) &
(extract('month', TokenUsage.date) == today.month)
).group_by(User.id).all()
return [
{
'user_id': r.id,
'username': r.username,
'monthly_budget': r.monthly_token_budget,
'current_usage': r.usage or 0,
'cost': r.cost or 0,
'percentage': ((r.usage or 0) / r.monthly_token_budget * 100)
if r.monthly_token_budget else 0
}
for r in results
]
def get_today_usage(self, user_id: int = None) -> Dict:
"""Get today's token usage."""
today = date.today()
query = db.session.query(
func.sum(TokenUsage.total_tokens).label('tokens'),
func.sum(TokenUsage.cost).label('cost')
).filter(TokenUsage.date == today)
if user_id:
query = query.filter(TokenUsage.user_id == user_id)
result = query.first()
return {
'tokens': result.tokens or 0,
'cost': result.cost or 0
}
# Singleton instance
token_tracker = TokenTracker()

View File

@@ -0,0 +1,98 @@
"""
Transcription service package.
Provides a connector-based architecture for speech-to-text transcription
with support for multiple providers:
- OpenAI Whisper (whisper-1)
- OpenAI GPT-4o Transcribe (gpt-4o-transcribe, gpt-4o-mini-transcribe, gpt-4o-transcribe-diarize)
- Custom ASR endpoints (whisper-asr-webservice, WhisperX, etc.)
Usage:
from src.services.transcription import (
transcribe,
get_connector,
supports_diarization,
TranscriptionRequest,
TranscriptionResponse,
)
# Simple transcription using active connector
with open('audio.mp3', 'rb') as f:
request = TranscriptionRequest(
audio_file=f,
filename='audio.mp3',
diarize=True
)
response = transcribe(request)
print(response.text)
if response.segments:
for seg in response.segments:
print(f"[{seg.speaker}]: {seg.text}")
"""
from .base import (
TranscriptionCapability,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionSegment,
BaseTranscriptionConnector,
ConnectorSpecifications,
DEFAULT_SPECIFICATIONS,
)
from .exceptions import (
TranscriptionError,
ConfigurationError,
ProviderError,
AudioFormatError,
ChunkingError,
)
from .registry import (
ConnectorRegistry,
get_registry,
connector_registry,
transcribe,
get_connector,
supports_diarization,
)
from .connectors import (
OpenAIWhisperConnector,
OpenAITranscribeConnector,
ASREndpointConnector,
)
__all__ = [
# Base types
'TranscriptionCapability',
'TranscriptionRequest',
'TranscriptionResponse',
'TranscriptionSegment',
'BaseTranscriptionConnector',
'ConnectorSpecifications',
'DEFAULT_SPECIFICATIONS',
# Exceptions
'TranscriptionError',
'ConfigurationError',
'ProviderError',
'AudioFormatError',
'ChunkingError',
# Registry
'ConnectorRegistry',
'get_registry',
'connector_registry',
# Convenience functions
'transcribe',
'get_connector',
'supports_diarization',
# Connectors
'OpenAIWhisperConnector',
'OpenAITranscribeConnector',
'ASREndpointConnector',
]

View File

@@ -0,0 +1,243 @@
"""
Base classes and data types for transcription connectors.
"""
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Optional, List, Dict, Any, BinaryIO, Set, Type, FrozenSet
class TranscriptionCapability(Enum):
"""Capabilities that connectors can declare support for."""
DIARIZATION = auto() # Speaker diarization
CHUNKING = auto() # Automatic file chunking for large files
TIMESTAMPS = auto() # Word/segment timestamps
LANGUAGE_DETECTION = auto() # Auto language detection
KNOWN_SPEAKERS = auto() # Support for known speaker references (future)
SPEAKER_EMBEDDINGS = auto() # Return speaker embeddings
SPEAKER_COUNT_CONTROL = auto() # Support for min/max speaker count parameters
STREAMING = auto() # Real-time streaming transcription
@dataclass
class ConnectorSpecifications:
"""
Provider-specific constraints and requirements.
Each connector declares its constraints so the application can automatically
handle chunking, format conversion, and other preprocessing as needed.
"""
# Size constraints
max_file_size_bytes: Optional[int] = None # None = unlimited
# Duration constraints
max_duration_seconds: Optional[int] = None # None = unlimited
min_duration_for_chunking: Optional[int] = None # Provider's internal chunking threshold
# Chunking behavior
handles_chunking_internally: bool = False # Provider handles large files
requires_chunking_param: bool = False # Must send chunking_strategy param
recommended_chunk_seconds: int = 600 # 10 minutes default
# Audio format support - connector-specific codec restrictions
# None = use system defaults from get_supported_codecs()
# Set = only allow these codecs (overrides defaults)
supported_codecs: Optional[FrozenSet[str]] = None
# Codecs this connector doesn't support (removed from defaults)
# Merged with AUDIO_UNSUPPORTED_CODECS env var
unsupported_codecs: Optional[FrozenSet[str]] = None
# Default specifications for connectors that don't define their own
DEFAULT_SPECIFICATIONS = ConnectorSpecifications()
@dataclass
class TranscriptionRequest:
"""Standardized transcription request."""
audio_file: BinaryIO
filename: str
mime_type: Optional[str] = None
language: Optional[str] = None
# Diarization options
diarize: bool = False
min_speakers: Optional[int] = None
max_speakers: Optional[int] = None
known_speaker_names: Optional[List[str]] = None
# known_speaker_references: Dict mapping speaker label to either BinaryIO or data URL string
known_speaker_references: Optional[Dict[str, Any]] = None
# Advanced options
prompt: Optional[str] = None
hotwords: Optional[str] = None # Comma-separated words to bias recognition
temperature: Optional[float] = None
# Provider-specific options (passthrough)
extra_options: Dict[str, Any] = field(default_factory=dict)
@dataclass
class TranscriptionSegment:
"""Single segment of transcription with optional metadata."""
text: str
speaker: Optional[str] = None
start_time: Optional[float] = None
end_time: Optional[float] = None
confidence: Optional[float] = None
words: Optional[List[Dict[str, Any]]] = None
@dataclass
class TranscriptionResponse:
"""Standardized transcription response."""
# Core content
text: str # Plain text transcription
segments: Optional[List[TranscriptionSegment]] = None # Detailed segments
# Metadata
language: Optional[str] = None # Detected language
duration: Optional[float] = None # Audio duration in seconds
# Speaker information
speakers: Optional[List[str]] = None # List of speakers found
speaker_embeddings: Optional[Dict[str, List[float]]] = None
# Provider info
provider: str = ""
model: str = ""
# Raw response for debugging
raw_response: Optional[Dict[str, Any]] = None
def to_storage_format(self) -> str:
"""
Convert to the JSON format used for storage in database.
Returns a JSON string in the format expected by the existing codebase:
[
{
"speaker": "SPEAKER_00",
"sentence": "Text here",
"start_time": 0.0,
"end_time": 5.5
},
...
]
"""
if self.segments:
return json.dumps([
{
'speaker': seg.speaker or 'Unknown Speaker',
'sentence': seg.text,
'start_time': seg.start_time,
'end_time': seg.end_time
}
for seg in self.segments
])
# If no segments, return plain text (for non-diarized transcriptions)
return self.text
def has_diarization(self) -> bool:
"""Check if this response contains diarization data."""
if not self.segments:
return False
return any(seg.speaker for seg in self.segments)
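# Illustrative round-trip (hypothetical values):
#   resp = TranscriptionResponse(text="hi", segments=[TranscriptionSegment(
#       text="hi", speaker="SPEAKER_00", start_time=0.0, end_time=1.2)])
#   resp.has_diarization()    # True
#   resp.to_storage_format()  # '[{"speaker": "SPEAKER_00", "sentence": "hi", ...}]'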
class BaseTranscriptionConnector(ABC):
"""Abstract base class for transcription connectors."""
# Class-level capability declarations - subclasses should override
CAPABILITIES: Set[TranscriptionCapability] = set()
PROVIDER_NAME: str = "unknown"
# Provider-specific constraints - subclasses should override
SPECIFICATIONS: ConnectorSpecifications = DEFAULT_SPECIFICATIONS
def __init__(self, config: Dict[str, Any]):
"""
Initialize connector with configuration.
Args:
config: Provider-specific configuration dict
"""
self.config = config
self._validate_config()
@abstractmethod
def _validate_config(self) -> None:
"""
Validate required configuration is present.
Raises:
ConfigurationError: If required config is missing or invalid
"""
pass
@abstractmethod
def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
"""
Perform transcription.
Args:
request: Standardized transcription request
Returns:
Standardized transcription response
Raises:
TranscriptionError: On transcription failure
ConfigurationError: On configuration issues
"""
pass
def supports(self, capability: TranscriptionCapability) -> bool:
"""Check if connector supports a capability."""
return capability in self.CAPABILITIES
def get_capabilities(self) -> Set[TranscriptionCapability]:
"""Get all supported capabilities."""
return self.CAPABILITIES.copy()
@property
def supports_diarization(self) -> bool:
"""Check if connector supports speaker diarization."""
return TranscriptionCapability.DIARIZATION in self.CAPABILITIES
@property
def supports_chunking(self) -> bool:
"""Check if connector supports automatic file chunking."""
return TranscriptionCapability.CHUNKING in self.CAPABILITIES
@property
def supports_speaker_count_control(self) -> bool:
"""Check if connector supports min/max speaker count parameters."""
return TranscriptionCapability.SPEAKER_COUNT_CONTROL in self.CAPABILITIES
@property
def specifications(self) -> ConnectorSpecifications:
"""Get connector specifications."""
return self.SPECIFICATIONS
@classmethod
def get_config_schema(cls) -> Dict[str, Any]:
"""
Return JSON schema for this connector's configuration.
Useful for admin UI and validation.
Returns:
JSON schema dict describing required and optional config
"""
return {}
def health_check(self) -> bool:
"""
Check if the connector is properly configured and reachable.
Returns:
True if the connector is healthy, False otherwise
"""
return True
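# Minimal connector sketch (hypothetical provider; assumes ConfigurationError is
# imported from the package's exceptions module):
#   class EchoConnector(BaseTranscriptionConnector):
#       CAPABILITIES = {TranscriptionCapability.TIMESTAMPS}
#       PROVIDER_NAME = "echo"
#       def _validate_config(self):
#           if not self.config.get('base_url'):
#               raise ConfigurationError("base_url is required")
#       def transcribe(self, request):
#           return TranscriptionResponse(text="", provider=self.PROVIDER_NAME)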

View File

@@ -0,0 +1,15 @@
"""
Transcription connector implementations.
"""
from .openai_whisper import OpenAIWhisperConnector
from .openai_transcribe import OpenAITranscribeConnector
from .asr_endpoint import ASREndpointConnector
from .azure_openai_transcribe import AzureOpenAITranscribeConnector
__all__ = [
'OpenAIWhisperConnector',
'OpenAITranscribeConnector',
'ASREndpointConnector',
'AzureOpenAITranscribeConnector',
]

View File

@@ -0,0 +1,337 @@
"""
ASR Endpoint connector for custom self-hosted ASR services.
Supports whisper-asr-webservice, WhisperX, and other compatible ASR services
that expose a /asr endpoint.
"""
import logging
import os
import httpx
from typing import Dict, Any, Set, Optional
from ..base import (
BaseTranscriptionConnector,
TranscriptionCapability,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionSegment,
ConnectorSpecifications,
)
from ..exceptions import TranscriptionError, ConfigurationError, ProviderError
from src.config.app_config import ASR_ENABLE_CHUNKING, ASR_MAX_DURATION_SECONDS
logger = logging.getLogger(__name__)
class ASREndpointConnector(BaseTranscriptionConnector):
"""Connector for custom ASR webservice (whisper-asr-webservice, WhisperX, etc.)."""
CAPABILITIES: Set[TranscriptionCapability] = {
TranscriptionCapability.DIARIZATION,
TranscriptionCapability.TIMESTAMPS,
TranscriptionCapability.LANGUAGE_DETECTION,
TranscriptionCapability.SPEAKER_COUNT_CONTROL, # Supports min/max speakers
}
PROVIDER_NAME = "asr_endpoint"
# SPECIFICATIONS is set dynamically in __init__ based on ASR_ENABLE_CHUNKING config
# Default values here for class-level reference (overridden per-instance)
SPECIFICATIONS = ConnectorSpecifications(
max_file_size_bytes=None,
max_duration_seconds=None,
handles_chunking_internally=True,
)
def __init__(self, config: Dict[str, Any]):
"""
Initialize the ASR Endpoint connector.
Args:
config: Configuration dict with keys:
- base_url: ASR service base URL (required)
- timeout: Request timeout in seconds (default: 1800)
- return_speaker_embeddings: Whether to request embeddings (default: False)
- diarize: Whether to enable diarization by default (default: True)
"""
super().__init__(config)
self.base_url = config['base_url'].rstrip('/')
self._config_timeout = config.get('timeout', 1800) # 30 minutes default
self.return_embeddings = config.get('return_speaker_embeddings', False)
self.default_diarize = config.get('diarize', True)
# Configure chunking behavior based on environment variables
# ASR_ENABLE_CHUNKING=true enables app-level chunking for self-hosted ASR services
# that may crash on long files due to GPU memory exhaustion
if ASR_ENABLE_CHUNKING:
# Calculate recommended chunk size (80% of max for safety margin)
recommended_chunk = int(ASR_MAX_DURATION_SECONDS * 0.8)
self.SPECIFICATIONS = ConnectorSpecifications(
max_file_size_bytes=None, # No file size limit
max_duration_seconds=ASR_MAX_DURATION_SECONDS,
handles_chunking_internally=False, # App handles chunking
recommended_chunk_seconds=recommended_chunk,
)
logger.info(
f"ASR chunking enabled: max_duration={ASR_MAX_DURATION_SECONDS}s, "
f"recommended_chunk={recommended_chunk}s"
)
else:
# Default behavior: ASR service handles everything internally
self.SPECIFICATIONS = ConnectorSpecifications(
max_file_size_bytes=None,
max_duration_seconds=None,
handles_chunking_internally=True,
)
# Add speaker embeddings capability if enabled
if self.return_embeddings:
self.CAPABILITIES = self.CAPABILITIES | {TranscriptionCapability.SPEAKER_EMBEDDINGS}
@property
def timeout(self):
"""Get ASR timeout, reading fresh from env/DB each time to respect runtime changes."""
# Environment variables take priority
env_timeout = os.environ.get('ASR_TIMEOUT') or os.environ.get('asr_timeout_seconds')
if env_timeout:
try:
return int(env_timeout)
except (ValueError, TypeError):
pass
# Try database setting (Admin UI)
try:
from src.models import SystemSetting
db_timeout = SystemSetting.get_setting('asr_timeout_seconds', None)
if db_timeout is not None:
return int(db_timeout)
except Exception:
pass
# Fall back to config value from initialization
return self._config_timeout
def _validate_config(self) -> None:
"""Validate required configuration."""
if not self.config.get('base_url'):
raise ConfigurationError("base_url is required for ASR endpoint connector")
def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
"""
Transcribe audio using ASR webservice.
Args:
request: Standardized transcription request
Returns:
TranscriptionResponse with segments and speaker information
"""
try:
url = f"{self.base_url}/asr"
params = {
'encode': True,
'task': 'transcribe',
'output': 'json'
}
if request.language:
params['language'] = request.language
logger.info(f"Using transcription language: {request.language}")
# Determine if we should diarize
should_diarize = request.diarize if request.diarize is not None else self.default_diarize
# Send both parameter names for compatibility:
# - 'diarize' is used by whisper-asr-webservice
# - 'enable_diarization' is used by WhisperX
params['diarize'] = should_diarize
params['enable_diarization'] = should_diarize
if should_diarize and self.return_embeddings:
params['return_speaker_embeddings'] = True
if request.min_speakers:
params['min_speakers'] = request.min_speakers
if request.max_speakers:
params['max_speakers'] = request.max_speakers
if request.prompt:
params['initial_prompt'] = request.prompt
if request.hotwords:
params['hotwords'] = request.hotwords
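            # Illustrative final params for a diarized French request:
            # {'encode': True, 'task': 'transcribe', 'output': 'json', 'language': 'fr',
            #  'diarize': True, 'enable_diarization': True, 'min_speakers': 2, 'max_speakers': 4}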
content_type = request.mime_type or 'application/octet-stream'
files = {
'audio_file': (request.filename, request.audio_file, content_type)
}
# Configure timeout: generous values for large file uploads
# Write timeout needs to be high too - large files take time to upload
timeout = httpx.Timeout(
None,
connect=60.0,
read=float(self.timeout),
write=float(self.timeout),
pool=None
)
logger.info(f"Sending ASR request to {url} with params: {params} (timeout: {self.timeout}s)")
with httpx.Client() as client:
response = client.post(url, params=params, files=files, timeout=timeout)
logger.info(f"ASR request completed with status: {response.status_code}")
response.raise_for_status()
# Parse the JSON response
response_text = response.text
try:
data = response.json()
except Exception as json_err:
if response_text.strip().startswith('<'):
logger.error(f"ASR returned HTML error page (status {response.status_code})")
raise ProviderError(
f"ASR service returned HTML error page",
provider=self.PROVIDER_NAME,
status_code=response.status_code
)
else:
raise ProviderError(
f"ASR service returned invalid response: {json_err}",
provider=self.PROVIDER_NAME,
status_code=response.status_code
)
return self._parse_response(data)
except httpx.HTTPStatusError as e:
logger.error(f"ASR request failed with status {e.response.status_code}")
raise ProviderError(
f"ASR request failed with status {e.response.status_code}",
provider=self.PROVIDER_NAME,
status_code=e.response.status_code
) from e
except httpx.TimeoutException as e:
logger.error(f"ASR request timed out after {self.timeout}s")
raise TranscriptionError(f"ASR request timed out after {self.timeout}s") from e
except Exception as e:
error_msg = str(e)
logger.error(f"ASR transcription failed: {error_msg}")
raise TranscriptionError(f"ASR transcription failed: {error_msg}") from e
def _parse_response(self, data: Dict[str, Any]) -> TranscriptionResponse:
"""
Parse ASR webservice response into standardized format.
The ASR response contains:
- text: Full transcription text
- language: Detected language
- segments: Array of segments with speaker, text, start, end
- speaker_embeddings: Optional speaker embeddings (WhisperX only)
"""
segments = []
speakers = set()
full_text_parts = []
last_speaker = None
logger.info(f"ASR response keys: {list(data.keys())}")
if 'segments' in data and isinstance(data['segments'], list):
logger.info(f"Number of segments: {len(data['segments'])}")
for seg in data['segments']:
speaker = seg.get('speaker')
# Handle missing speakers by carrying forward from previous segment
if speaker is None:
if last_speaker is not None:
speaker = last_speaker
else:
speaker = 'UNKNOWN_SPEAKER'
else:
last_speaker = speaker
text = seg.get('text', '').strip()
speakers.add(speaker)
full_text_parts.append(f"[{speaker}]: {text}")
segments.append(TranscriptionSegment(
text=text,
speaker=speaker,
start_time=seg.get('start'),
end_time=seg.get('end')
))
# Get the full text
if 'text' in data and isinstance(data['text'], str):
full_text = data['text']
elif full_text_parts:
full_text = '\n'.join(full_text_parts)
else:
full_text = ''
# Extract speaker embeddings if present
speaker_embeddings = data.get('speaker_embeddings')
if speaker_embeddings:
logger.info(f"Received speaker embeddings for speakers: {list(speaker_embeddings.keys())}")
logger.info(f"Parsed {len(segments)} segments with {len(speakers)} unique speakers: {sorted(speakers)}")
return TranscriptionResponse(
text=full_text,
segments=segments,
speakers=sorted(list(speakers)),
speaker_embeddings=speaker_embeddings,
language=data.get('language'),
provider=self.PROVIDER_NAME,
model="asr-endpoint",
raw_response=data
)
def health_check(self) -> bool:
"""Check if ASR endpoint is reachable."""
try:
with httpx.Client(timeout=10.0) as client:
# Try common health check endpoints
for endpoint in ['/health', '/']:
try:
response = client.get(f"{self.base_url}{endpoint}")
if response.status_code < 500:
return True
except Exception:
continue
return False
except Exception:
return False
@classmethod
def get_config_schema(cls) -> Dict[str, Any]:
"""Return JSON schema for configuration."""
return {
"type": "object",
"required": ["base_url"],
"properties": {
"base_url": {
"type": "string",
"description": "ASR service base URL (e.g., http://whisper-asr:9000)"
},
"timeout": {
"type": "integer",
"default": 1800,
"description": "Request timeout in seconds"
},
"diarize": {
"type": "boolean",
"default": True,
"description": "Enable speaker diarization by default"
},
"return_speaker_embeddings": {
"type": "boolean",
"default": False,
"description": "Request speaker embeddings (WhisperX only)"
}
}
}
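# Illustrative instantiation (hypothetical service URL):
#   connector = ASREndpointConnector({'base_url': 'http://whisper-asr:9000',
#                                     'timeout': 1800, 'diarize': True})
#   connector.health_check()  # probes /health, then /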

View File

@@ -0,0 +1,370 @@
"""
Azure OpenAI Transcribe connector.
Supports Azure OpenAI audio transcription models:
- whisper-1: Basic transcription (no diarization)
- gpt-4o-transcribe: High quality transcription
- gpt-4o-mini-transcribe: Cost-effective transcription
- gpt-4o-transcribe-diarize: Speaker diarization with labels A, B, C, D
Azure OpenAI uses a different API format than standard OpenAI:
- Endpoint: https://{resource}.openai.azure.com/openai/deployments/{deployment}/audio/transcriptions
- Requires api-version query parameter
- Uses api-key header for authentication
"""
import logging
import httpx
from typing import Dict, Any, Set, Optional
from ..base import (
BaseTranscriptionConnector,
TranscriptionCapability,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionSegment,
ConnectorSpecifications,
)
from ..exceptions import TranscriptionError, ConfigurationError
logger = logging.getLogger(__name__)
class AzureOpenAITranscribeConnector(BaseTranscriptionConnector):
"""Connector for Azure OpenAI audio transcription models."""
# Base capabilities - diarization added dynamically based on model
CAPABILITIES: Set[TranscriptionCapability] = {
TranscriptionCapability.TIMESTAMPS,
TranscriptionCapability.LANGUAGE_DETECTION,
}
PROVIDER_NAME = "azure_openai_transcribe"
# Default specifications (will be overridden per-model in __init__)
SPECIFICATIONS = ConnectorSpecifications(
max_file_size_bytes=25 * 1024 * 1024, # 25MB
max_duration_seconds=1400, # Default to most restrictive (diarize model)
min_duration_for_chunking=30,
handles_chunking_internally=False,
requires_chunking_param=True,
recommended_chunk_seconds=1200,
unsupported_codecs=frozenset({'opus'}),
)
# Models and their capabilities
MODELS = {
'whisper-1': {
'supports_diarization': False,
'max_duration_seconds': 1500,
'recommended_chunk_seconds': 1200,
'description': 'OpenAI Whisper model on Azure'
},
'gpt-4o-transcribe': {
'supports_diarization': False,
'max_duration_seconds': 1500,
'recommended_chunk_seconds': 1200,
'description': 'High quality transcription'
},
'gpt-4o-mini-transcribe': {
'supports_diarization': False,
'max_duration_seconds': 1500,
'recommended_chunk_seconds': 1200,
'description': 'Cost-effective transcription'
},
'gpt-4o-mini-transcribe-2025-12-15': {
'supports_diarization': False,
'max_duration_seconds': 1500,
'recommended_chunk_seconds': 1200,
'description': 'Cost-effective transcription (dated version)'
},
'gpt-4o-transcribe-diarize': {
'supports_diarization': True,
'max_duration_seconds': 1400,
'recommended_chunk_seconds': 1200,
'description': 'Speaker diarization with labels A, B, C, D'
}
}
# Default API version - can be overridden in config
DEFAULT_API_VERSION = "2025-04-01-preview"
def __init__(self, config: Dict[str, Any]):
"""
Initialize the Azure OpenAI Transcribe connector.
Args:
config: Configuration dict with keys:
- api_key: Azure OpenAI API key (required)
- endpoint: Azure OpenAI endpoint URL (required)
e.g., https://your-resource.openai.azure.com
- deployment_name: The deployment name for the model (required)
- api_version: API version (default: 2025-04-01-preview)
- model: Model name for validation (optional, defaults to deployment_name)
"""
# Store model/deployment before calling super().__init__
self.deployment_name = config.get('deployment_name', '')
self.model = config.get('model', self.deployment_name)
self.api_version = config.get('api_version', self.DEFAULT_API_VERSION)
# Set model-specific specifications
model_info = self.MODELS.get(self.model, {})
if model_info:
self.SPECIFICATIONS = ConnectorSpecifications(
max_file_size_bytes=25 * 1024 * 1024,
max_duration_seconds=model_info.get('max_duration_seconds', 1400),
min_duration_for_chunking=30,
handles_chunking_internally=False,
requires_chunking_param=True,
recommended_chunk_seconds=model_info.get('recommended_chunk_seconds', 1200),
unsupported_codecs=frozenset({'opus'}),
)
super().__init__(config)
# Parse endpoint URL
self.endpoint = config['endpoint'].rstrip('/')
# Set up HTTP client
self.http_client = httpx.Client(
timeout=httpx.Timeout(
connect=60.0,
read=1800.0, # 30 minutes for long transcriptions
write=1800.0,
pool=None
),
headers={
"api-key": config['api_key'],
"User-Agent": "Speakr/1.0 (https://github.com/murtaza-nasir/speakr)"
}
)
# Dynamically update capabilities based on model
if self._model_supports_diarization():
self.CAPABILITIES = self.CAPABILITIES | {
TranscriptionCapability.DIARIZATION,
TranscriptionCapability.KNOWN_SPEAKERS
}
def _validate_config(self) -> None:
"""Validate required configuration."""
if not self.config.get('api_key'):
raise ConfigurationError("api_key is required for Azure OpenAI Transcribe connector")
if not self.config.get('endpoint'):
raise ConfigurationError("endpoint is required for Azure OpenAI Transcribe connector")
if not self.config.get('deployment_name'):
raise ConfigurationError("deployment_name is required for Azure OpenAI Transcribe connector")
def _model_supports_diarization(self) -> bool:
"""Check if the current model supports diarization."""
model_info = self.MODELS.get(self.model, {})
return model_info.get('supports_diarization', False)
def _build_url(self) -> str:
"""Build the Azure OpenAI transcription API URL."""
return f"{self.endpoint}/openai/deployments/{self.deployment_name}/audio/transcriptions?api-version={self.api_version}"
def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
"""
Transcribe audio using Azure OpenAI API.
Args:
request: Standardized transcription request
Returns:
TranscriptionResponse, with segments if using diarization model
"""
try:
url = self._build_url()
# Build form data
data = {}
if request.language:
data["language"] = request.language
logger.info(f"Using transcription language: {request.language}")
# Handle diarization model specifics
is_diarize_model = 'diarize' in self.model.lower()
if is_diarize_model:
# Required: chunking_strategy for audio > 30 seconds
data["chunking_strategy"] = "auto"
if request.diarize:
data["response_format"] = "diarized_json"
logger.info("Using diarized_json response format for speaker diarization")
# Known speaker support
if request.known_speaker_names and request.known_speaker_references:
for i, name in enumerate(request.known_speaker_names):
if name in request.known_speaker_references:
data[f"known_speaker_names[{i}]"] = name
data[f"known_speaker_references[{i}]"] = request.known_speaker_references[name]
logger.info(f"Using known speaker references for {len(request.known_speaker_names)} speakers")
else:
# Non-diarization models - request verbose_json for timestamps
data["response_format"] = "verbose_json"
# Combine initial prompt and hotwords into a single prompt
prompt_parts = []
if request.prompt:
prompt_parts.append(request.prompt)
if request.hotwords:
prompt_parts.append(request.hotwords)
if prompt_parts:
data["prompt"] = ". ".join(prompt_parts)
# Prepare file for upload
content_type = request.mime_type or 'application/octet-stream'
files = {
"file": (request.filename, request.audio_file, content_type)
}
logger.info(f"Sending request to Azure OpenAI: {url}")
logger.info(f"Model: {self.model}, Deployment: {self.deployment_name}")
response = self.http_client.post(url, data=data, files=files)
if response.status_code != 200:
error_detail = response.text
try:
error_json = response.json()
if 'error' in error_json:
error_detail = error_json['error'].get('message', error_detail)
except (ValueError, AttributeError):
pass  # body was not JSON (or not the expected shape); keep the raw error text
logger.error(f"Azure OpenAI transcription failed: {response.status_code} - {error_detail}")
raise TranscriptionError(f"Azure OpenAI transcription failed: {response.status_code} - {error_detail}")
result = response.json()
# Parse response based on format
if is_diarize_model and request.diarize:
return self._parse_diarized_response(result)
else:
return self._parse_response(result)
except TranscriptionError:
raise
except Exception as e:
error_msg = str(e)
logger.error(f"Azure OpenAI transcription failed: {error_msg}")
raise TranscriptionError(f"Azure OpenAI transcription failed: {error_msg}") from e
def _parse_response(self, response: Dict) -> TranscriptionResponse:
"""Parse a standard (non-diarized) response."""
text = response.get('text', '')
# Check for segments (verbose_json format)
segments = []
if 'segments' in response:
for seg in response['segments']:
segments.append(TranscriptionSegment(
text=seg.get('text', ''),
start_time=seg.get('start'),
end_time=seg.get('end')
))
return TranscriptionResponse(
text=text,
segments=segments if segments else None,
language=response.get('language'),
provider=self.PROVIDER_NAME,
model=self.model,
raw_response=response
)
def _parse_diarized_response(self, response: Dict) -> TranscriptionResponse:
"""
Parse diarized JSON response into standardized format.
The diarized_json response contains segments with:
- speaker: "A", "B", "C", "D" etc.
- text: The transcribed text
- start: Segment start time
- end: Segment end time
"""
segments = []
speakers = set()
full_text_parts = []
raw_segments = response.get('segments', [])
if not raw_segments:
# Fallback to text-only response
logger.warning("No segments found in diarized response, falling back to text")
return self._parse_response(response)
for seg in raw_segments:
speaker = seg.get('speaker', 'Unknown')
text = seg.get('text', '')
start = seg.get('start')
end = seg.get('end')
# Skip empty segments
if not text or not text.strip():
continue
speakers.add(speaker)
full_text_parts.append(f"[{speaker}]: {text}")
segments.append(TranscriptionSegment(
text=text,
speaker=speaker,
start_time=start,
end_time=end
))
# Build full text with speaker labels
full_text = '\n'.join(full_text_parts)
logger.info(f"Parsed {len(segments)} segments with {len(speakers)} unique speakers: {sorted(speakers)}")
return TranscriptionResponse(
text=full_text,
segments=segments,
speakers=sorted(speakers),
language=response.get('language'),
provider=self.PROVIDER_NAME,
model=self.model,
raw_response=response
)
def health_check(self) -> bool:
"""Check if the connector is properly configured."""
return bool(
self.config.get('api_key') and
self.config.get('endpoint') and
self.config.get('deployment_name')
)
@classmethod
def get_config_schema(cls) -> Dict[str, Any]:
"""Return JSON schema for configuration."""
return {
"type": "object",
"required": ["api_key", "endpoint", "deployment_name"],
"properties": {
"api_key": {
"type": "string",
"description": "Azure OpenAI API key"
},
"endpoint": {
"type": "string",
"description": "Azure OpenAI endpoint URL (e.g., https://your-resource.openai.azure.com)"
},
"deployment_name": {
"type": "string",
"description": "The deployment name for your transcription model"
},
"api_version": {
"type": "string",
"default": cls.DEFAULT_API_VERSION,
"description": "Azure OpenAI API version"
},
"model": {
"type": "string",
"enum": list(cls.MODELS.keys()),
"description": "Model type (for capability detection, defaults to deployment_name)"
}
}
}
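# --- Hedged usage sketch (illustration only, not part of the original file) ---
# A minimal demo of the connector above, assuming this module's imports are in
# scope and that TranscriptionRequest accepts the fields used in transcribe();
# the endpoint, deployment, and file name are placeholders.
if __name__ == '__main__':
    import os

    connector = AzureOpenAITranscribeConnector({
        'api_key': os.environ.get('TRANSCRIPTION_API_KEY', ''),
        'endpoint': 'https://my-resource.openai.azure.com',
        'deployment_name': 'gpt-4o-transcribe-diarize',
        'model': 'gpt-4o-transcribe-diarize',  # enables DIARIZATION capability
    })
    with open('meeting.wav', 'rb') as audio:
        result = connector.transcribe(TranscriptionRequest(
            audio_file=audio,
            filename='meeting.wav',
            mime_type='audio/wav',
            diarize=True,
        ))
    print(result.text)  # "[A]: ...\n[B]: ..." when diarization succeeds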

View File

@@ -0,0 +1,329 @@
"""
OpenAI GPT-4o Transcribe connector.
Supports the newer GPT-4o based transcription models:
- gpt-4o-transcribe: High quality transcription
- gpt-4o-mini-transcribe: Cost-effective transcription
- gpt-4o-transcribe-diarize: Speaker diarization with labels A, B, C, D
"""
import logging
import httpx
from openai import OpenAI
from typing import Dict, Any, Set, Optional
from ..base import (
BaseTranscriptionConnector,
TranscriptionCapability,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionSegment,
ConnectorSpecifications,
)
from ..exceptions import TranscriptionError, ConfigurationError
logger = logging.getLogger(__name__)
class OpenAITranscribeConnector(BaseTranscriptionConnector):
"""Connector for GPT-4o Transcribe models with optional diarization support."""
# Base capabilities - diarization added dynamically based on model
CAPABILITIES: Set[TranscriptionCapability] = {
TranscriptionCapability.TIMESTAMPS,
TranscriptionCapability.LANGUAGE_DETECTION,
}
PROVIDER_NAME = "openai_transcribe"
# GPT-4o Transcribe models have specific constraints
# - 25MB file size limit (all models)
# - Duration limits vary by model:
# - gpt-4o-transcribe / gpt-4o-mini-transcribe: 1500 seconds (25 min)
# - gpt-4o-transcribe-diarize: 1400 seconds (~23 min)
# - chunking_strategy="auto" handles files internally up to the duration limit
# Supported formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, flac, ogg, oga
# NOT supported: opus (used by WhatsApp voice notes, Discord)
# Default specifications (will be overridden per-model in __init__)
SPECIFICATIONS = ConnectorSpecifications(
max_file_size_bytes=25 * 1024 * 1024, # 25MB
max_duration_seconds=1400, # Default to most restrictive (diarize model)
min_duration_for_chunking=30, # >30s needs chunking_strategy param
handles_chunking_internally=False, # App must chunk files > max_duration_seconds
requires_chunking_param=True, # Must send chunking_strategy for >30s
recommended_chunk_seconds=1200, # 20 minutes - safe margin
unsupported_codecs=frozenset({'opus'}), # OpenAI API doesn't support opus
)
# Models and their capabilities with duration limits
MODELS = {
'gpt-4o-transcribe': {
'supports_diarization': False,
'max_duration_seconds': 1500, # 25 minutes
'recommended_chunk_seconds': 1200, # 20 minutes
'description': 'High quality transcription'
},
'gpt-4o-mini-transcribe': {
'supports_diarization': False,
'max_duration_seconds': 1500, # 25 minutes
'recommended_chunk_seconds': 1200, # 20 minutes
'description': 'Cost-effective transcription'
},
'gpt-4o-mini-transcribe-2025-12-15': {
'supports_diarization': False,
'max_duration_seconds': 1500, # 25 minutes
'recommended_chunk_seconds': 1200, # 20 minutes
'description': 'Cost-effective transcription (dated version)'
},
'gpt-4o-transcribe-diarize': {
'supports_diarization': True,
'max_duration_seconds': 1400, # ~23 minutes (more restrictive)
'recommended_chunk_seconds': 1200, # 20 minutes
'description': 'Speaker diarization with labels A, B, C, D'
}
}
def __init__(self, config: Dict[str, Any]):
"""
Initialize the GPT-4o Transcribe connector.
Args:
config: Configuration dict with keys:
- api_key: OpenAI API key (required)
- base_url: API base URL (default: https://api.openai.com/v1)
- model: Model name (default: gpt-4o-transcribe; must be one of MODELS)
- http_client: Optional httpx.Client instance
"""
# Store model before calling super().__init__ since _validate_config needs it
self.model = config.get('model', 'gpt-4o-transcribe')
# Set model-specific specifications (override class defaults)
# Use SPECIFICATIONS (uppercase) to shadow the class attribute
model_info = self.MODELS.get(self.model, {})
self.SPECIFICATIONS = ConnectorSpecifications(
max_file_size_bytes=25 * 1024 * 1024, # 25MB (same for all)
max_duration_seconds=model_info.get('max_duration_seconds', 1400),
min_duration_for_chunking=30,
handles_chunking_internally=False,
requires_chunking_param=True,
recommended_chunk_seconds=model_info.get('recommended_chunk_seconds', 1200),
unsupported_codecs=frozenset({'opus'}),
)
super().__init__(config)
# Set up HTTP client with custom headers
http_client = config.get('http_client')
if not http_client:
app_headers = {
"HTTP-Referer": "https://github.com/murtaza-nasir/speakr",
"X-Title": "Speakr - AI Audio Transcription",
"User-Agent": "Speakr/1.0 (https://github.com/murtaza-nasir/speakr)"
}
http_client = httpx.Client(verify=True, headers=app_headers)
self.client = OpenAI(
api_key=config['api_key'],
base_url=config.get('base_url', 'https://api.openai.com/v1'),
http_client=http_client
)
# Dynamically update capabilities based on model
if self._model_supports_diarization():
self.CAPABILITIES = self.CAPABILITIES | {
TranscriptionCapability.DIARIZATION,
TranscriptionCapability.KNOWN_SPEAKERS
}
def _validate_config(self) -> None:
"""Validate required configuration."""
if not self.config.get('api_key'):
raise ConfigurationError("api_key is required for OpenAI Transcribe connector")
model = self.config.get('model', 'gpt-4o-transcribe')
if model not in self.MODELS:
raise ConfigurationError(
f"Unknown model: {model}. Valid models: {list(self.MODELS.keys())}"
)
def _model_supports_diarization(self) -> bool:
"""Check if the current model supports diarization."""
model_info = self.MODELS.get(self.model, {})
return model_info.get('supports_diarization', False)
def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
"""
Transcribe audio using GPT-4o Transcribe API.
Args:
request: Standardized transcription request
Returns:
TranscriptionResponse, with segments if using diarization model
"""
try:
params = {
"model": self.model,
"file": request.audio_file,
}
if request.language:
params["language"] = request.language
logger.info(f"Using transcription language: {request.language}")
# Handle diarization model specifics
if self.model == 'gpt-4o-transcribe-diarize':
# Required: chunking_strategy for audio > 30 seconds
params["chunking_strategy"] = "auto"
if request.diarize:
params["response_format"] = "diarized_json"
logger.info("Using diarized_json response format for speaker diarization")
# Known speaker support for maintaining speaker identity across chunks
# known_speaker_names is a list of speaker labels (e.g., ["A", "B"])
# known_speaker_references is a dict mapping label to data URL
if request.known_speaker_names and request.known_speaker_references:
# OpenAI expects lists for both parameters
speaker_names = []
speaker_refs = []
for name in request.known_speaker_names:
if name in request.known_speaker_references:
speaker_names.append(name)
speaker_refs.append(request.known_speaker_references[name])
if speaker_names:
# Use extra_body to pass the known speaker parameters
params["extra_body"] = {
"known_speaker_names": speaker_names,
"known_speaker_references": speaker_refs
}
logger.info(f"Using known speaker references for {len(speaker_names)} speakers: {speaker_names}")
else:
# Non-diarization models - combine initial prompt and hotwords
prompt_parts = []
if request.prompt:
prompt_parts.append(request.prompt)
if request.hotwords:
prompt_parts.append(request.hotwords)
if prompt_parts:
params["prompt"] = ". ".join(prompt_parts)
logger.info(f"Sending request to GPT-4o Transcribe API with model: {self.model}")
response = self.client.audio.transcriptions.create(**params)
# Parse response based on format
if self.model == 'gpt-4o-transcribe-diarize' and request.diarize:
return self._parse_diarized_response(response)
else:
return self._parse_text_response(response)
except Exception as e:
error_msg = str(e)
logger.error(f"GPT-4o transcription failed: {error_msg}")
raise TranscriptionError(f"GPT-4o transcription failed: {error_msg}") from e
def _parse_text_response(self, response) -> TranscriptionResponse:
"""Parse a plain text response."""
text = response.text if hasattr(response, 'text') else str(response)
return TranscriptionResponse(
text=text,
provider=self.PROVIDER_NAME,
model=self.model
)
def _parse_diarized_response(self, response) -> TranscriptionResponse:
"""
Parse diarized JSON response into standardized format.
The diarized_json response contains segments with:
- speaker: "A", "B", "C", "D" etc.
- text: The transcribed text
- start: Segment start time
- end: Segment end time
"""
segments = []
speakers = set()
full_text_parts = []
# Handle response object - could be dict or object with attributes
if hasattr(response, 'segments'):
raw_segments = response.segments
elif isinstance(response, dict) and 'segments' in response:
raw_segments = response['segments']
else:
# Fallback to text-only response
logger.warning("No segments found in diarized response, falling back to text")
return self._parse_text_response(response)
for seg in raw_segments:
# Handle both dict and object segments
if isinstance(seg, dict):
speaker = seg.get('speaker', 'Unknown')
text = seg.get('text', '')
start = seg.get('start')
end = seg.get('end')
else:
speaker = getattr(seg, 'speaker', 'Unknown')
text = getattr(seg, 'text', '')
start = getattr(seg, 'start', None)
end = getattr(seg, 'end', None)
# Skip empty segments
if not text or not text.strip():
continue
speakers.add(speaker)
full_text_parts.append(f"[{speaker}]: {text}")
segments.append(TranscriptionSegment(
text=text,
speaker=speaker,
start_time=start,
end_time=end
))
# Always use our formatted text with speaker labels for diarized responses
# OpenAI's response.text is plain text WITHOUT speaker labels
full_text = '\n'.join(full_text_parts)
logger.info(f"Parsed {len(segments)} segments with {len(speakers)} unique speakers: {sorted(speakers)}")
return TranscriptionResponse(
text=full_text,
segments=segments,
speakers=sorted(speakers),
provider=self.PROVIDER_NAME,
model=self.model,
raw_response=response if isinstance(response, dict) else None
)
def health_check(self) -> bool:
"""Check if the connector is properly configured."""
return bool(self.config.get('api_key'))
@classmethod
def get_config_schema(cls) -> Dict[str, Any]:
"""Return JSON schema for configuration."""
return {
"type": "object",
"required": ["api_key"],
"properties": {
"api_key": {
"type": "string",
"description": "OpenAI API key"
},
"base_url": {
"type": "string",
"default": "https://api.openai.com/v1",
"description": "API base URL"
},
"model": {
"type": "string",
"enum": list(cls.MODELS.keys()),
"default": "gpt-4o-transcribe",
"description": "GPT-4o transcription model to use"
}
}
}
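# --- Hedged usage sketch (illustration only, not part of the original file) ---
# Minimal demo of the diarizing model, assuming this module's imports are in
# scope and TranscriptionRequest accepts the fields used above; the API key
# env var and file name are placeholders.
if __name__ == '__main__':
    import os

    connector = OpenAITranscribeConnector({
        'api_key': os.environ['OPENAI_API_KEY'],
        'model': 'gpt-4o-transcribe-diarize',
    })
    with open('call.mp3', 'rb') as audio:
        result = connector.transcribe(TranscriptionRequest(
            audio_file=audio, filename='call.mp3', diarize=True))
    for seg in result.segments or []:
        print(f"[{seg.speaker}] {seg.start_time}-{seg.end_time}: {seg.text}")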

View File

@@ -0,0 +1,153 @@
"""
OpenAI Whisper API connector (whisper-1 model).
This is the legacy Whisper API connector that supports the whisper-1 model.
It returns plain text transcriptions without speaker diarization.
"""
import logging
import os
import httpx
from openai import OpenAI
from typing import Dict, Any, Set
from ..base import (
BaseTranscriptionConnector,
TranscriptionCapability,
TranscriptionRequest,
TranscriptionResponse,
ConnectorSpecifications,
)
from ..exceptions import TranscriptionError, ConfigurationError
logger = logging.getLogger(__name__)
class OpenAIWhisperConnector(BaseTranscriptionConnector):
"""Connector for OpenAI Whisper API (whisper-1 model)."""
CAPABILITIES: Set[TranscriptionCapability] = {
TranscriptionCapability.CHUNKING,
TranscriptionCapability.TIMESTAMPS,
TranscriptionCapability.LANGUAGE_DETECTION,
}
PROVIDER_NAME = "openai_whisper"
# OpenAI Whisper has a 25MB file limit and doesn't handle chunking internally
# Supported formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, flac, ogg, oga
# NOT supported: opus (used by WhatsApp voice notes, Discord)
SPECIFICATIONS = ConnectorSpecifications(
max_file_size_bytes=25 * 1024 * 1024, # 25MB
handles_chunking_internally=False,
recommended_chunk_seconds=600, # 10 minutes
unsupported_codecs=frozenset({'opus'}), # OpenAI API doesn't support opus
)
def __init__(self, config: Dict[str, Any]):
"""
Initialize the Whisper connector.
Args:
config: Configuration dict with keys:
- api_key: OpenAI API key (required)
- base_url: API base URL (optional)
- model: Model name (default: whisper-1)
- http_client: Optional httpx.Client instance
"""
super().__init__(config)
# Set up HTTP client with custom headers
http_client = config.get('http_client')
if not http_client:
app_headers = {
"HTTP-Referer": "https://github.com/murtaza-nasir/speakr",
"X-Title": "Speakr - AI Audio Transcription",
"User-Agent": "Speakr/1.0 (https://github.com/murtaza-nasir/speakr)"
}
http_client = httpx.Client(verify=True, headers=app_headers)
self.client = OpenAI(
api_key=config['api_key'],
base_url=config.get('base_url') or None,
http_client=http_client
)
self.model = config.get('model', 'whisper-1')
def _validate_config(self) -> None:
"""Validate required configuration."""
if not self.config.get('api_key'):
raise ConfigurationError("api_key is required for OpenAI Whisper connector")
def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
"""
Transcribe audio using OpenAI Whisper API.
Args:
request: Standardized transcription request
Returns:
TranscriptionResponse with plain text (no diarization)
"""
try:
params = {
"model": self.model,
"file": request.audio_file,
}
if request.language:
params["language"] = request.language
logger.info(f"Using transcription language: {request.language}")
# Combine initial prompt and hotwords into a single prompt
# OpenAI Whisper uses prompt for both steering and vocabulary hints
prompt_parts = []
if request.prompt:
prompt_parts.append(request.prompt)
if request.hotwords:
prompt_parts.append(request.hotwords)
if prompt_parts:
params["prompt"] = ". ".join(prompt_parts)
if request.temperature is not None:
params["temperature"] = request.temperature
logger.info(f"Sending request to Whisper API with model: {self.model}")
transcript = self.client.audio.transcriptions.create(**params)
return TranscriptionResponse(
text=transcript.text,
provider=self.PROVIDER_NAME,
model=self.model
)
except Exception as e:
error_msg = str(e)
logger.error(f"Whisper transcription failed: {error_msg}")
raise TranscriptionError(f"Whisper transcription failed: {error_msg}") from e
def health_check(self) -> bool:
"""Check if the connector is properly configured."""
return bool(self.config.get('api_key'))
@classmethod
def get_config_schema(cls) -> Dict[str, Any]:
"""Return JSON schema for configuration."""
return {
"type": "object",
"required": ["api_key"],
"properties": {
"api_key": {
"type": "string",
"description": "OpenAI API key"
},
"base_url": {
"type": "string",
"description": "API base URL (optional, for OpenAI-compatible endpoints)"
},
"model": {
"type": "string",
"default": "whisper-1",
"description": "Whisper model to use"
}
}
}
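# --- Hedged usage sketch (illustration only, not part of the original file) ---
# Demonstrates how prompt and hotwords are merged into a single Whisper prompt,
# assuming TranscriptionRequest accepts the fields used above; values are
# placeholders.
if __name__ == '__main__':
    connector = OpenAIWhisperConnector({'api_key': os.environ['OPENAI_API_KEY']})
    with open('note.m4a', 'rb') as audio:
        result = connector.transcribe(TranscriptionRequest(
            audio_file=audio,
            filename='note.m4a',
            prompt='Weekly stand-up for the DictIA team',
            hotwords='Loi 25, AGPL, Speakr',  # sent as "<prompt>. <hotwords>"
        ))
    print(result.text)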

View File

@@ -0,0 +1,32 @@
"""
Custom exceptions for transcription services.
"""
class TranscriptionError(Exception):
"""Base exception for transcription errors."""
pass
class ConfigurationError(TranscriptionError):
"""Configuration-related errors (missing or invalid config)."""
pass
class ProviderError(TranscriptionError):
"""Provider/API errors."""
def __init__(self, message: str, provider: str = None, status_code: int = None):
super().__init__(message)
self.provider = provider
self.status_code = status_code
class AudioFormatError(TranscriptionError):
"""Unsupported audio format errors."""
pass
class ChunkingError(TranscriptionError):
"""Errors during file chunking."""
pass

View File

@@ -0,0 +1,353 @@
"""
Connector registry for managing transcription connectors.
Provides factory pattern for creating and managing transcription connectors,
with auto-detection from environment variables for backwards compatibility.
"""
import os
import logging
from typing import Dict, Any, Optional, Type, List
from .base import BaseTranscriptionConnector, TranscriptionCapability, TranscriptionRequest, TranscriptionResponse
from .exceptions import ConfigurationError
logger = logging.getLogger(__name__)
class ConnectorRegistry:
"""
Registry for managing transcription connectors.
Singleton pattern - use get_registry() to get the shared instance.
"""
_instance = None
_connectors: Dict[str, Type[BaseTranscriptionConnector]] = {}
_active_connector: Optional[BaseTranscriptionConnector] = None
_connector_name: str = ""
_initialized: bool = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if self._initialized:
return
self._register_builtin_connectors()
self._initialized = True
def _register_builtin_connectors(self):
"""Register all built-in connectors."""
from .connectors.openai_whisper import OpenAIWhisperConnector
from .connectors.openai_transcribe import OpenAITranscribeConnector
from .connectors.asr_endpoint import ASREndpointConnector
from .connectors.azure_openai_transcribe import AzureOpenAITranscribeConnector
self.register('openai_whisper', OpenAIWhisperConnector)
self.register('openai_transcribe', OpenAITranscribeConnector)
self.register('asr_endpoint', ASREndpointConnector)
self.register('azure_openai_transcribe', AzureOpenAITranscribeConnector)
def register(self, name: str, connector_class: Type[BaseTranscriptionConnector]):
"""
Register a connector class.
Args:
name: Unique name for the connector
connector_class: The connector class to register
"""
self._connectors[name] = connector_class
logger.debug(f"Registered transcription connector: {name}")
def get_connector_class(self, name: str) -> Type[BaseTranscriptionConnector]:
"""
Get a connector class by name.
Args:
name: The connector name
Returns:
The connector class
Raises:
ConfigurationError: If connector not found
"""
if name not in self._connectors:
raise ConfigurationError(
f"Unknown connector: {name}. Available: {list(self._connectors.keys())}"
)
return self._connectors[name]
def create_connector(self, name: str, config: Dict[str, Any]) -> BaseTranscriptionConnector:
"""
Create a connector instance.
Args:
name: The connector name
config: Configuration dict for the connector
Returns:
Configured connector instance
"""
connector_class = self.get_connector_class(name)
return connector_class(config)
def list_connectors(self) -> List[Dict[str, Any]]:
"""
List all registered connectors with their capabilities.
Returns:
List of connector info dicts
"""
result = []
for name, cls in self._connectors.items():
result.append({
'name': name,
'provider_name': cls.PROVIDER_NAME,
'capabilities': [c.name for c in cls.CAPABILITIES],
'config_schema': cls.get_config_schema()
})
return result
def initialize_from_env(self) -> BaseTranscriptionConnector:
"""
Initialize the active connector from environment variables.
Auto-detection priority:
1. TRANSCRIPTION_CONNECTOR - explicit connector name
2. ASR_BASE_URL is set - use ASR endpoint
- USE_ASR_ENDPOINT=true also works (backwards compat, with deprecation warning)
2.5. TRANSCRIPTION_BASE_URL is an Azure endpoint - use Azure OpenAI Transcribe
3. TRANSCRIPTION_MODEL contains 'gpt-4o' - use OpenAI Transcribe
4. TRANSCRIPTION_MODEL is set - use OpenAI Whisper with that model
5. Default to OpenAI Whisper (whisper-1)
Returns:
The initialized connector
"""
connector_name = os.environ.get('TRANSCRIPTION_CONNECTOR', '').lower().strip()
if not connector_name:
# Auto-detect based on existing config for backwards compatibility
asr_base_url = os.environ.get('ASR_BASE_URL', '').strip()
use_asr_flag = os.environ.get('USE_ASR_ENDPOINT', 'false').lower() == 'true'
transcription_model = os.environ.get('TRANSCRIPTION_MODEL', '').lower()
whisper_model = os.environ.get('WHISPER_MODEL', '').lower()
# Deprecation warning for legacy USE_ASR_ENDPOINT flag
if use_asr_flag:
logger.warning(
"USE_ASR_ENDPOINT=true is deprecated. "
"Set ASR_BASE_URL instead for auto-detection, or use TRANSCRIPTION_CONNECTOR=asr_endpoint"
)
# Priority 2: ASR endpoint - check ASR_BASE_URL or legacy flag
if asr_base_url or use_asr_flag:
connector_name = 'asr_endpoint'
if asr_base_url:
logger.info("Auto-detected ASR endpoint from ASR_BASE_URL")
# Priority 2.5: Azure OpenAI - check for Azure endpoint URL
elif self._is_azure_endpoint():
connector_name = 'azure_openai_transcribe'
logger.info("Auto-detected Azure OpenAI from TRANSCRIPTION_BASE_URL")
# Priority 3: Model-based detection
elif transcription_model and 'gpt-4o' in transcription_model:
connector_name = 'openai_transcribe'
logger.info(f"Auto-detected OpenAI Transcribe from TRANSCRIPTION_MODEL={transcription_model}")
# Priority 4 & 5: OpenAI Whisper (with custom or default model)
else:
connector_name = 'openai_whisper'
model = transcription_model or whisper_model or 'whisper-1'
logger.info(f"Using OpenAI Whisper connector with model: {model}")
config = self._build_config_from_env(connector_name)
try:
self._active_connector = self.create_connector(connector_name, config)
self._connector_name = connector_name
logger.info(f"Initialized transcription connector: {connector_name}")
logger.info(f"Capabilities: {[c.name for c in self._active_connector.get_capabilities()]}")
return self._active_connector
except Exception as e:
logger.error(f"Failed to initialize connector '{connector_name}': {e}")
raise ConfigurationError(f"Failed to initialize connector '{connector_name}': {e}") from e
def _get_asr_timeout(self) -> int:
"""
Get ASR timeout with fallback chain: ENV -> Admin UI -> default.
Priority:
1. ASR_TIMEOUT environment variable
2. asr_timeout_seconds environment variable (legacy)
3. SystemSetting from Admin UI (database)
4. Default: 1800 seconds (30 minutes)
"""
# Check environment variables first
env_timeout = os.environ.get('ASR_TIMEOUT') or os.environ.get('asr_timeout_seconds')
if env_timeout:
    try:
        return int(env_timeout)
    except ValueError:
        logger.warning(f"Invalid ASR timeout value '{env_timeout}'; falling back")
# Fall back to Admin UI setting (SystemSetting in database)
try:
from src.models import SystemSetting
db_timeout = SystemSetting.get_setting('asr_timeout_seconds', None)
if db_timeout is not None:
return int(db_timeout)
except Exception as e:
# May fail if no app context or during initialization
logger.debug(f"Could not read ASR timeout from database: {e}")
# Default: 30 minutes
return 1800
def _is_azure_endpoint(self) -> bool:
"""Check if the TRANSCRIPTION_BASE_URL points to an Azure OpenAI endpoint."""
base_url = os.environ.get('TRANSCRIPTION_BASE_URL', '').lower()
return '.openai.azure.com' in base_url or '.cognitiveservices.azure.com' in base_url
def _build_config_from_env(self, connector_name: str) -> Dict[str, Any]:
"""
Build connector config from environment variables.
Args:
connector_name: The connector to build config for
Returns:
Configuration dict
"""
if connector_name == 'asr_endpoint':
base_url = os.environ.get('ASR_BASE_URL', '')
if base_url:
base_url = base_url.split('#')[0].strip()
return {
'base_url': base_url,
'timeout': self._get_asr_timeout(),
'diarize': os.environ.get('ASR_DIARIZE', 'true').lower() == 'true',
'return_speaker_embeddings': os.environ.get('ASR_RETURN_SPEAKER_EMBEDDINGS', 'false').lower() == 'true'
}
elif connector_name == 'openai_transcribe':
base_url = os.environ.get('TRANSCRIPTION_BASE_URL', 'https://api.openai.com/v1')
if base_url:
base_url = base_url.split('#')[0].strip()
return {
'api_key': os.environ.get('TRANSCRIPTION_API_KEY', ''),
'base_url': base_url,
'model': os.environ.get('TRANSCRIPTION_MODEL', 'gpt-4o-transcribe')
}
elif connector_name == 'azure_openai_transcribe':
# Azure OpenAI requires endpoint and deployment_name
# TRANSCRIPTION_BASE_URL should be the Azure endpoint (e.g., https://your-resource.openai.azure.com)
endpoint = os.environ.get('TRANSCRIPTION_BASE_URL', '')
if endpoint:
endpoint = endpoint.split('#')[0].strip()
# Remove any trailing /openai or /v1 paths - we build the full URL ourselves
endpoint = endpoint.rstrip('/')
for suffix in ['/openai/v1', '/openai', '/v1']:
if endpoint.lower().endswith(suffix):
endpoint = endpoint[:-len(suffix)]
return {
'api_key': os.environ.get('TRANSCRIPTION_API_KEY', ''),
'endpoint': endpoint,
'deployment_name': os.environ.get('AZURE_DEPLOYMENT_NAME', os.environ.get('TRANSCRIPTION_MODEL', 'gpt-4o-transcribe')),
'api_version': os.environ.get('AZURE_API_VERSION', '2025-04-01-preview'),
'model': os.environ.get('TRANSCRIPTION_MODEL', '') # For capability detection
}
else: # openai_whisper (default)
base_url = os.environ.get('TRANSCRIPTION_BASE_URL', '')
if base_url:
base_url = base_url.split('#')[0].strip()
# Support both TRANSCRIPTION_MODEL and legacy WHISPER_MODEL
# TRANSCRIPTION_MODEL takes priority for custom Whisper variants
model = os.environ.get('TRANSCRIPTION_MODEL', '') or os.environ.get('WHISPER_MODEL', 'whisper-1')
return {
'api_key': os.environ.get('TRANSCRIPTION_API_KEY', ''),
'base_url': base_url or None,
'model': model
}
def get_active_connector(self) -> BaseTranscriptionConnector:
"""
Get the currently active connector.
Initializes from environment if not already initialized.
Returns:
The active connector
"""
if not self._active_connector:
self.initialize_from_env()
return self._active_connector
def get_active_connector_name(self) -> str:
"""Get the name of the currently active connector."""
if not self._active_connector:
self.initialize_from_env()
return self._connector_name
def reinitialize(self) -> BaseTranscriptionConnector:
"""
Force re-initialization of the connector.
Useful when environment variables have changed.
Returns:
The newly initialized connector
"""
self._active_connector = None
self._connector_name = ""
return self.initialize_from_env()
# Global registry instance
_registry: Optional[ConnectorRegistry] = None
def get_registry() -> ConnectorRegistry:
"""Get the global connector registry."""
global _registry
if _registry is None:
_registry = ConnectorRegistry()
return _registry
# Convenience aliases
connector_registry = get_registry()
def transcribe(request: TranscriptionRequest) -> TranscriptionResponse:
"""
Transcribe audio using the active connector.
This is a convenience function that uses the global registry.
Args:
request: The transcription request
Returns:
Transcription response
"""
connector = get_registry().get_active_connector()
return connector.transcribe(request)
def get_connector() -> BaseTranscriptionConnector:
"""Get the active transcription connector."""
return get_registry().get_active_connector()
def supports_diarization() -> bool:
"""Check if the active connector supports diarization."""
return get_registry().get_active_connector().supports_diarization
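# --- Hedged usage sketch (illustration only, not part of the original file) ---
# Explicit connector selection alongside the env-driven path; config values
# are placeholders.
if __name__ == '__main__':
    registry = get_registry()
    for info in registry.list_connectors():
        print(info['name'], info['capabilities'])
    connector = registry.create_connector(
        'openai_whisper', {'api_key': 'sk-placeholder'})
    print(connector.health_check())  # True: api_key is present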

View File

@@ -0,0 +1,312 @@
"""
Transcription usage tracking service for monitoring audio transcription consumption and enforcing per-user budgets.
"""
import logging
from datetime import date, datetime, timedelta
from typing import Tuple, Optional, Dict, List
from sqlalchemy import func, extract
from src.database import db
from src.models.transcription_usage import TranscriptionUsage
from src.models.user import User
logger = logging.getLogger(__name__)
# Pricing configuration per connector/model (dollars per minute)
TRANSCRIPTION_PRICING = {
'openai_whisper': {
'whisper-1': 0.006, # $0.006/min
'default': 0.006,
},
'openai_transcribe': {
'gpt-4o-transcribe': 0.006, # $0.006/min
'gpt-4o-mini-transcribe': 0.003, # $0.003/min
'gpt-4o-mini-transcribe-2025-12-15': 0.003,
'gpt-4o-transcribe-diarize': 0.006,
'default': 0.006,
},
'asr_endpoint': {
'default': 0.0, # Self-hosted = free
},
}
def get_transcription_cost_per_minute(connector_type: str, model_name: str = None) -> float:
"""
Get the cost per minute for a given connector and model.
Args:
connector_type: The connector provider name
model_name: The specific model (optional)
Returns:
Cost per minute in dollars
"""
connector_pricing = TRANSCRIPTION_PRICING.get(connector_type, {})
if model_name and model_name in connector_pricing:
return connector_pricing[model_name]
# Fall back to 'default' pricing for the connector
return connector_pricing.get('default', 0.0)
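# --- Hedged worked example (illustration only, not part of the original file) ---
# 15 minutes (900 s) on gpt-4o-mini-transcribe at $0.003/min:
#   (900 / 60.0) * get_transcription_cost_per_minute(
#       'openai_transcribe', 'gpt-4o-mini-transcribe')  # -> 0.045 ($0.045)
# Connectors without a pricing entry (e.g. azure_openai_transcribe) fall back
# to connector_pricing.get('default', 0.0), i.e. they are billed at $0.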
class TranscriptionTracker:
"""Service for recording and checking transcription usage."""
CONNECTOR_TYPES = [
'openai_whisper',
'openai_transcribe',
'asr_endpoint',
'azure_openai_transcribe',  # registered in the connector registry; no pricing entry yet, falls back to $0/min
]
def record_usage(
self,
user_id: int,
connector_type: str,
audio_duration_seconds: int,
model_name: str = None,
estimated_cost: float = None
):
"""
Record transcription usage - upserts into daily aggregate.
Args:
user_id: User ID who made the request
connector_type: Type of connector (openai_whisper, openai_transcribe, asr_endpoint)
audio_duration_seconds: Duration of audio transcribed in seconds
model_name: Name of the model used
estimated_cost: Pre-calculated cost (if None, calculated from pricing config)
"""
try:
today = date.today()
# Calculate cost if not provided
if estimated_cost is None:
cost_per_minute = get_transcription_cost_per_minute(connector_type, model_name)
estimated_cost = (audio_duration_seconds / 60.0) * cost_per_minute
# Find or create today's record for this user/connector
usage = TranscriptionUsage.query.filter_by(
user_id=user_id,
date=today,
connector_type=connector_type
).first()
if usage:
# Update existing record
usage.audio_duration_seconds += audio_duration_seconds
usage.estimated_cost += estimated_cost
usage.request_count += 1
if model_name:
usage.model_name = model_name # Update to latest model used
else:
# Create new record
usage = TranscriptionUsage(
user_id=user_id,
date=today,
connector_type=connector_type,
audio_duration_seconds=audio_duration_seconds,
request_count=1,
model_name=model_name,
estimated_cost=estimated_cost or 0.0
)
db.session.add(usage)
db.session.commit()
logger.debug(f"Recorded {audio_duration_seconds}s transcription for user {user_id}, connector {connector_type}")
return usage
except Exception as e:
logger.error(f"Failed to record transcription usage: {e}")
db.session.rollback()
return None
def get_monthly_usage(self, user_id: int, year: int = None, month: int = None) -> int:
"""Get total seconds transcribed by a user in a given month."""
if year is None:
year = date.today().year
if month is None:
month = date.today().month
result = db.session.query(func.sum(TranscriptionUsage.audio_duration_seconds)).filter(
TranscriptionUsage.user_id == user_id,
extract('year', TranscriptionUsage.date) == year,
extract('month', TranscriptionUsage.date) == month
).scalar()
return result or 0
def get_monthly_cost(self, user_id: int, year: int = None, month: int = None) -> float:
"""Get total estimated cost for a user in a given month."""
if year is None:
year = date.today().year
if month is None:
month = date.today().month
result = db.session.query(func.sum(TranscriptionUsage.estimated_cost)).filter(
TranscriptionUsage.user_id == user_id,
extract('year', TranscriptionUsage.date) == year,
extract('month', TranscriptionUsage.date) == month
).scalar()
return result or 0.0
def check_budget(self, user_id: int) -> Tuple[bool, float, Optional[str]]:
"""
Check if user is within transcription budget.
Returns:
(can_proceed, usage_percentage, message)
- can_proceed: False if hard cap (100%) reached
- usage_percentage: 0-100+
- message: Warning/error message if applicable
"""
try:
user = db.session.get(User, user_id)
if not user or not user.monthly_transcription_budget:
return (True, 0, None) # No budget = unlimited
current_usage = self.get_monthly_usage(user_id)
budget = user.monthly_transcription_budget
percentage = (current_usage / budget) * 100
if percentage >= 100:
minutes_used = current_usage // 60
minutes_budget = budget // 60
return (False, percentage,
f"Monthly transcription budget exceeded ({minutes_used}/{minutes_budget} minutes). Contact admin for more time.")
elif percentage >= 80:
return (True, percentage,
f"Warning: {percentage:.1f}% of monthly transcription budget used.")
else:
return (True, percentage, None)
except Exception as e:
logger.error(f"Failed to check transcription budget for user {user_id}: {e}")
# Fail open - allow the request if we can't check
return (True, 0, None)
def get_daily_stats(self, days: int = 30, user_id: int = None) -> List[Dict]:
"""Get daily transcription usage for charts."""
start_date = date.today() - timedelta(days=days - 1)
query = db.session.query(
TranscriptionUsage.date,
TranscriptionUsage.connector_type,
func.sum(TranscriptionUsage.audio_duration_seconds).label('seconds'),
func.sum(TranscriptionUsage.estimated_cost).label('cost')
).filter(TranscriptionUsage.date >= start_date)
if user_id:
query = query.filter(TranscriptionUsage.user_id == user_id)
results = query.group_by(TranscriptionUsage.date, TranscriptionUsage.connector_type).all()
# Organize by date
stats = {}
for r in results:
date_str = r.date.isoformat()
if date_str not in stats:
stats[date_str] = {'date': date_str, 'total_seconds': 0, 'total_minutes': 0, 'cost': 0.0, 'by_connector': {}}
stats[date_str]['total_seconds'] += r.seconds or 0
stats[date_str]['total_minutes'] = stats[date_str]['total_seconds'] // 60
stats[date_str]['cost'] += r.cost or 0
stats[date_str]['by_connector'][r.connector_type] = {
'seconds': r.seconds or 0,
'minutes': (r.seconds or 0) // 60
}
# Fill in missing dates with zeros
all_dates = []
current = start_date
while current <= date.today():
date_str = current.isoformat()
if date_str not in stats:
stats[date_str] = {'date': date_str, 'total_seconds': 0, 'total_minutes': 0, 'cost': 0.0, 'by_connector': {}}
all_dates.append(date_str)
current += timedelta(days=1)
return [stats[d] for d in sorted(all_dates)]
def get_monthly_stats(self, months: int = 12) -> List[Dict]:
"""Get monthly transcription usage for charts."""
results = db.session.query(
extract('year', TranscriptionUsage.date).label('year'),
extract('month', TranscriptionUsage.date).label('month'),
func.sum(TranscriptionUsage.audio_duration_seconds).label('seconds'),
func.sum(TranscriptionUsage.estimated_cost).label('cost')
).group_by('year', 'month').order_by('year', 'month').all()
# Get last N months
monthly_data = [
{
'year': int(r.year),
'month': int(r.month),
'seconds': r.seconds or 0,
'minutes': (r.seconds or 0) // 60,
'cost': r.cost or 0
}
for r in results
]
return monthly_data[-months:]  # slicing already handles len(monthly_data) <= months
def get_user_stats(self) -> List[Dict]:
"""Get per-user transcription usage breakdown for current month."""
today = date.today()
results = db.session.query(
User.id,
User.username,
User.monthly_transcription_budget,
func.sum(TranscriptionUsage.audio_duration_seconds).label('usage'),
func.sum(TranscriptionUsage.estimated_cost).label('cost')
).outerjoin(
TranscriptionUsage,
(User.id == TranscriptionUsage.user_id) &
(extract('year', TranscriptionUsage.date) == today.year) &
(extract('month', TranscriptionUsage.date) == today.month)
).group_by(User.id).all()
return [
{
'user_id': r.id,
'username': r.username,
'monthly_budget_seconds': r.monthly_transcription_budget,
'monthly_budget_minutes': (r.monthly_transcription_budget // 60) if r.monthly_transcription_budget else None,
'current_usage_seconds': r.usage or 0,
'current_usage_minutes': (r.usage or 0) // 60,
'cost': r.cost or 0,
'percentage': ((r.usage or 0) / r.monthly_transcription_budget * 100)
if r.monthly_transcription_budget else 0
}
for r in results
]
def get_today_usage(self, user_id: int = None) -> Dict:
"""Get today's transcription usage."""
today = date.today()
query = db.session.query(
func.sum(TranscriptionUsage.audio_duration_seconds).label('seconds'),
func.sum(TranscriptionUsage.estimated_cost).label('cost')
).filter(TranscriptionUsage.date == today)
if user_id:
query = query.filter(TranscriptionUsage.user_id == user_id)
result = query.first()
return {
'seconds': result.seconds or 0,
'minutes': (result.seconds or 0) // 60,
'cost': result.cost or 0
}
# Singleton instance
transcription_tracker = TranscriptionTracker()
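# --- Hedged usage sketch (illustration only, not part of the original file) ---
# Typical call pattern around a transcription job; assumes an active Flask app
# context so the SQLAlchemy session is available, and user_id=1 is a placeholder.
if __name__ == '__main__':
    can_proceed, pct, message = transcription_tracker.check_budget(user_id=1)
    if not can_proceed:
        print(message)  # hard cap reached - refuse the job
    else:
        if message:
            print(message)  # soft warning at >= 80% of budget
        # ... run the transcription, then record what was actually consumed ...
        transcription_tracker.record_usage(
            user_id=1,
            connector_type='openai_transcribe',
            audio_duration_seconds=900,
            model_name='gpt-4o-mini-transcribe',
        )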