Initial release: DictIA v0.8.14-alpha (fork de Speakr, AGPL-3.0)
This commit is contained in:
34
src/services/__init__.py
Normal file
34
src/services/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Service layer for business logic.
|
||||
"""
|
||||
|
||||
from .embeddings import *
|
||||
from .llm import *
|
||||
from .document import *
|
||||
from .retention import *
|
||||
|
||||
__all__ = [
|
||||
# Embedding services
|
||||
'get_embedding_model',
|
||||
'chunk_transcription',
|
||||
'generate_embeddings',
|
||||
'serialize_embedding',
|
||||
'deserialize_embedding',
|
||||
'get_accessible_recording_ids',
|
||||
'process_recording_chunks',
|
||||
'basic_text_search_chunks',
|
||||
'semantic_search_chunks',
|
||||
# LLM services
|
||||
'is_gpt5_model',
|
||||
'is_using_openai_api',
|
||||
'call_llm_completion',
|
||||
'call_chat_completion',
|
||||
'chat_client',
|
||||
'format_api_error_message',
|
||||
# Document services
|
||||
'process_markdown_to_docx',
|
||||
# Retention services
|
||||
'is_recording_exempt_from_deletion',
|
||||
'get_retention_days_for_recording',
|
||||
'process_auto_deletion',
|
||||
]
|
||||
211
src/services/audit.py
Normal file
211
src/services/audit.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Central audit service for Loi 25 compliance.
|
||||
|
||||
Wraps AccessLog and AuthLog with request context helpers.
|
||||
All logging is gated behind ENABLE_AUDIT_LOG env var.
|
||||
|
||||
NOTE: ENABLE_AUDIT_LOG is read once at import time. Changing the env var
|
||||
requires an application restart to take effect.
|
||||
|
||||
NOTE: Audit helpers do NOT commit the session — the caller's transaction
|
||||
will persist the log entry when it commits. This avoids interfering with
|
||||
ongoing transactions (e.g. a delete operation that audits then deletes).
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from flask import request, has_request_context
|
||||
from flask_login import current_user
|
||||
|
||||
from src.database import db
|
||||
from src.models.access_log import AccessLog
|
||||
from src.models.auth_log import AuthLog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Read once at import — requires restart to change.
|
||||
ENABLE_AUDIT_LOG = os.environ.get('ENABLE_AUDIT_LOG', 'false').lower() == 'true'
|
||||
|
||||
|
||||
def is_audit_enabled():
    """Return True when audit logging is turned on (env-gated at import time)."""
    return ENABLE_AUDIT_LOG


def _get_request_context():
    """Return (ip, user_agent) from the active request, or (None, None).

    Safe to call outside a request context (e.g. from CLI jobs).
    """
    if has_request_context():
        # Cap the UA string so oversized headers cannot bloat the log table.
        return request.remote_addr, request.headers.get('User-Agent', '')[:500]
    return None, None


def _get_current_user_id():
    """Return the authenticated user's id, or None outside a logged-in request."""
    if not has_request_context():
        return None
    if current_user and current_user.is_authenticated:
        return current_user.id
    return None
|
||||
|
||||
|
||||
# --- Access logging helpers ---
|
||||
|
||||
_DEDUP_WINDOW_SECONDS = 300 # 5 minutes
|
||||
|
||||
|
||||
def _is_recent_duplicate(action, resource_type, resource_id, user_id):
    """Return True if the same user already logged this action on this resource in the last 5 min.

    Prevents unbounded log growth for high-frequency read operations (e.g. view on every GET).
    """
    # Anonymous or resource-less events are never deduplicated.
    if user_id is None or resource_id is None:
        return False
    window_start = datetime.utcnow() - timedelta(seconds=_DEDUP_WINDOW_SECONDS)
    criteria = (
        AccessLog.user_id == user_id,
        AccessLog.action == action,
        AccessLog.resource_type == resource_type,
        AccessLog.resource_id == resource_id,
        AccessLog.timestamp >= window_start,
    )
    return AccessLog.query.filter(*criteria).first() is not None
|
||||
|
||||
|
||||
def audit_access(action, resource_type, resource_id=None, user_id=None, status='success', details=None):
    """Log a data access event if audit is enabled.

    Does NOT commit — the entry joins the current session and is persisted
    when the caller (or Flask teardown) commits.

    'view' events are deduplicated: the same user viewing the same resource
    within 5 minutes is logged only once to avoid unbounded log growth.
    """
    if not ENABLE_AUDIT_LOG:
        return None
    try:
        ip, ua = _get_request_context()
        actor_id = _get_current_user_id() if user_id is None else user_id
        # Throttle repeated reads of the same resource by the same user.
        if action == 'view' and _is_recent_duplicate(action, resource_type, resource_id, actor_id):
            return None
        return AccessLog.log_access(
            action=action,
            resource_type=resource_type,
            resource_id=resource_id,
            user_id=actor_id,
            status=status,
            details=details,
            ip_address=ip,
            user_agent=ua,
        )
    except Exception as e:
        # Auditing must never break the operation being audited.
        logger.warning(f"Failed to write access audit log: {e}")
        return None
|
||||
|
||||
|
||||
# Thin convenience wrappers over audit_access; **kwargs forwards the optional
# user_id / status / details arguments unchanged.

def audit_view(resource_type, resource_id, **kwargs):
    """Log a view access (deduplicated per user/resource over a 5-minute window)."""
    return audit_access('view', resource_type, resource_id, **kwargs)


def audit_download(resource_type, resource_id, **kwargs):
    """Log a download access."""
    return audit_access('download', resource_type, resource_id, **kwargs)


def audit_edit(resource_type, resource_id, **kwargs):
    """Log an edit."""
    return audit_access('edit', resource_type, resource_id, **kwargs)


def audit_delete(resource_type, resource_id, **kwargs):
    """Log a deletion."""
    return audit_access('delete', resource_type, resource_id, **kwargs)


def audit_export(resource_type, resource_id, **kwargs):
    """Log a data export."""
    return audit_access('export', resource_type, resource_id, **kwargs)
|
||||
|
||||
|
||||
# --- Auth logging helpers ---
|
||||
# Auth events (login/logout/failed) are standalone operations, so they
|
||||
# commit their own transaction since the caller typically redirects after.
|
||||
|
||||
def _audit_auth(func, *args, **kwargs):
    """Invoke an AuthLog factory with request context attached, committing on its own.

    Auth events are standalone operations (the caller typically redirects
    right after), so this commits immediately; on any failure the session is
    rolled back and None is returned.
    """
    if not ENABLE_AUDIT_LOG:
        return None
    try:
        ip, ua = _get_request_context()
        entry = func(*args, ip_address=ip, user_agent=ua, **kwargs)
        db.session.commit()
    except Exception as e:
        logger.warning(f"Failed to write auth audit log: {e}")
        db.session.rollback()
        return None
    return entry
|
||||
|
||||
|
||||
def audit_login(user_id, details=None):
    """Log a successful login."""
    return _audit_auth(AuthLog.log_login, user_id, details=details)


def audit_logout(user_id=None):
    """Log a logout.

    When user_id is omitted, falls back to the currently authenticated user.
    """
    if user_id is None:
        user_id = _get_current_user_id()
    return _audit_auth(AuthLog.log_logout, user_id)


def audit_failed_login(details=None):
    """Log a failed login (no user_id — the attempt may not map to a known user)."""
    return _audit_auth(AuthLog.log_failed_login, details=details)


def audit_register(user_id):
    """Log a registration."""
    return _audit_auth(AuthLog.log_register, user_id)


def audit_password_change(user_id, details=None):
    """Log a password change."""
    return _audit_auth(AuthLog.log_password_change, user_id, details=details)


def audit_password_reset(user_id):
    """Log a password reset."""
    return _audit_auth(AuthLog.log_password_reset, user_id)


def audit_sso_login(user_id, details=None):
    """Log an SSO login."""
    return _audit_auth(AuthLog.log_sso_login, user_id, details=details)
|
||||
|
||||
|
||||
# --- Query helpers for admin ---
|
||||
|
||||
def get_access_logs(page=1, per_page=50, user_id=None, resource_type=None, resource_id=None, action=None):
    """Query access logs with pagination and optional equality filters.

    Results are newest-first; a None filter value means "no filter".
    """
    query = AccessLog.query.order_by(AccessLog.timestamp.desc())
    optional_filters = {
        'user_id': user_id,
        'resource_type': resource_type,
        'resource_id': resource_id,
        'action': action,
    }
    for column, value in optional_filters.items():
        if value is not None:
            query = query.filter_by(**{column: value})
    return query.paginate(page=page, per_page=per_page, error_out=False)


def get_auth_logs(page=1, per_page=50, user_id=None, action=None):
    """Query auth logs with pagination and optional equality filters.

    Results are newest-first; a None filter value means "no filter".
    """
    query = AuthLog.query.order_by(AuthLog.timestamp.desc())
    optional_filters = {'user_id': user_id, 'action': action}
    for column, value in optional_filters.items():
        if value is not None:
            query = query.filter_by(**{column: value})
    return query.paginate(page=page, per_page=per_page, error_out=False)
|
||||
102
src/services/calendar.py
Normal file
102
src/services/calendar.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Calendar/ICS file generation services.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def generate_ics_content(event):
    """Generate iCalendar (ICS, RFC 5545) file content for an event.

    Args:
        event: object exposing id, title, description, location, attendees
            (a JSON-encoded list of display names), start_datetime,
            end_datetime and reminder_minutes.

    Returns:
        str: the ICS document, CRLF-joined as RFC 5545 requires.
    """
    # uuid / datetime are imported at module level; no local re-import needed.

    # Globally unique identifier for this VEVENT.
    uid = f"{event.id}-{uuid.uuid4()}@speakr.app"

    def format_ical_date(dt):
        """Format a datetime as a floating iCalendar date-time (YYYYMMDDTHHMMSS)."""
        if dt:
            return dt.strftime('%Y%m%dT%H%M%S')
        return None

    lines = [
        'BEGIN:VCALENDAR',
        'VERSION:2.0',
        'PRODID:-//Speakr//Event Export//EN',
        'CALSCALE:GREGORIAN',
        'METHOD:PUBLISH',
        'BEGIN:VEVENT',
        f'UID:{uid}',
        f'DTSTAMP:{format_ical_date(datetime.utcnow())}',
    ]

    if event.start_datetime:
        lines.append(f'DTSTART:{format_ical_date(event.start_datetime)}')

    if event.end_datetime:
        lines.append(f'DTEND:{format_ical_date(event.end_datetime)}')
    elif event.start_datetime:
        # If no end time, default to 1 hour after start.
        end_time = event.start_datetime + timedelta(hours=1)
        lines.append(f'DTEND:{format_ical_date(end_time)}')

    lines.append(f'SUMMARY:{escape_ical_text(event.title)}')

    if event.description:
        lines.append(f'DESCRIPTION:{escape_ical_text(event.description)}')

    if event.location:
        lines.append(f'LOCATION:{escape_ical_text(event.location)}')

    # Attendees are stored as a JSON list of display names. A malformed value
    # skips attendees (best-effort) instead of failing the whole export; the
    # previous bare `except:` also swallowed SystemExit/KeyboardInterrupt.
    if event.attendees:
        try:
            attendees_list = json.loads(event.attendees)
            for attendee in attendees_list:
                if attendee:
                    # NOTE(review): synthesizes a placeholder @example.com
                    # mailto from the display name — confirm this is intended.
                    lines.append(f'ATTENDEE:CN={escape_ical_text(attendee)}:mailto:{attendee.replace(" ", ".").lower()}@example.com')
        except (ValueError, TypeError, AttributeError):
            pass

    # Optional display alarm; TRIGGER uses a negative ISO-8601 duration.
    if event.reminder_minutes and event.reminder_minutes > 0:
        lines.extend([
            'BEGIN:VALARM',
            f'TRIGGER:-PT{event.reminder_minutes}M',
            'ACTION:DISPLAY',
            f'DESCRIPTION:Reminder: {escape_ical_text(event.title)}',
            'END:VALARM'
        ])

    lines.extend([
        'STATUS:CONFIRMED',
        'TRANSP:OPAQUE',
        'END:VEVENT',
        'END:VCALENDAR'
    ])

    return '\r\n'.join(lines)
|
||||
|
||||
|
||||
|
||||
def escape_ical_text(text):
    """Escape special characters for iCalendar format.

    Falsy input yields '' ; any other value is stringified first.
    """
    if not text:
        return ''
    escaped = str(text)
    # Backslash is escaped first so later escapes are not doubled up.
    for raw, safe in (('\\', '\\\\'), (',', '\\,'), (';', '\\;'), ('\n', '\\n')):
        escaped = escaped.replace(raw, safe)
    return escaped
|
||||
|
||||
|
||||
|
||||
296
src/services/document.py
Normal file
296
src/services/document.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
Document processing and conversion services.
|
||||
"""
|
||||
|
||||
import re
|
||||
from docx import Document
|
||||
from docx.shared import Pt, RGBColor
|
||||
|
||||
|
||||
|
||||
def process_markdown_to_docx(doc, content):
    """Convert markdown content to properly formatted Word document elements.

    Args:
        doc: python-docx Document to which converted elements are appended.
        content: Markdown source text.

    Supports:
    - Tables (markdown pipe tables)
    - Headings (# ## ###)
    - Bold text (**text**)
    - Italic text (*text* or _text_)
    - Bold italic (***text***)
    - Inline code (`code`)
    - Code blocks (```code```)
    - Strikethrough (~~text~~)
    - Links ([text](url))
    - Bullet lists (- or *)
    - Numbered lists (1. 2. 3.)
    - Horizontal rules (--- or ***)
    """
    from docx.shared import RGBColor, Pt
    from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
    from docx.oxml.ns import qn
    import re

    def ensure_unicode_font(run, text):
        """Ensure the run uses a font that supports the characters in the text."""
        # Check if text contains non-ASCII characters
        try:
            text.encode('ascii')
            # Text is pure ASCII, no special font needed
        except UnicodeEncodeError:
            # Text contains non-ASCII characters, use a font with better Unicode support
            # Use Arial for broad compatibility - it has good Unicode support on most systems
            run.font.name = 'Arial'
            # Set the East Asian font for CJK (Chinese, Japanese, Korean) text
            # This ensures proper rendering in Word
            r = run._element
            r.rPr.rFonts.set(qn('w:eastAsia'), 'Arial')
        return run

    def add_formatted_run(paragraph, text):
        """Add a run with inline formatting to a paragraph."""
        if not text:
            return

        # Pattern for all inline formatting.
        # Order matters: check triple asterisk before double/single.
        # Handlers are lambdas so add_code_run/add_link_run (defined below)
        # are resolved lazily at call time.
        patterns = [
            (r'\*\*\*(.*?)\*\*\*', lambda p, t: (lambda r: (setattr(r, 'bold', True), setattr(r, 'italic', True), ensure_unicode_font(r, t)))(p.add_run(t))),  # Bold italic
            (r'\*\*(.*?)\*\*', lambda p, t: (lambda r: (setattr(r, 'bold', True), ensure_unicode_font(r, t)))(p.add_run(t))),  # Bold
            (r'(?<!\*)\*(?!\*)(.*?)\*(?!\*)', lambda p, t: (lambda r: (setattr(r, 'italic', True), ensure_unicode_font(r, t)))(p.add_run(t))),  # Italic with *
            (r'\b_(.*?)_\b', lambda p, t: (lambda r: (setattr(r, 'italic', True), ensure_unicode_font(r, t)))(p.add_run(t))),  # Italic with _
            (r'~~(.*?)~~', lambda p, t: (lambda r: (setattr(r, 'strike', True), ensure_unicode_font(r, t)))(p.add_run(t))),  # Strikethrough
            (r'`([^`]+)`', lambda p, t: add_code_run(p, t)),  # Inline code
            (r'\[([^\]]+)\]\(([^)]+)\)', lambda p, t, u: add_link_run(p, t, u)),  # Links
        ]

        def add_code_run(para, text):
            """Add inline code with monospace font and background."""
            run = para.add_run(text)
            run.font.name = 'Courier New'
            run.font.size = Pt(10)
            run.font.color.rgb = RGBColor(220, 20, 60)  # Crimson color for code
            # Check if we need Unicode support for code
            try:
                text.encode('ascii')
            except UnicodeEncodeError:
                # Use Consolas as fallback for better Unicode support in monospace
                r = run._element
                r.rPr.rFonts.set(qn('w:eastAsia'), 'Consolas')
            return run

        def add_link_run(para, text, url):
            """Add a hyperlink-styled run (note: actual hyperlinks require more complex handling)."""
            full_text = f"{text} ({url})"
            run = para.add_run(full_text)
            run.font.color.rgb = RGBColor(0, 0, 255)  # Blue color for links
            run.font.underline = True
            ensure_unicode_font(run, full_text)
            return run

        # Process the text with all patterns, always consuming the leftmost
        # match so formatting cannot overlap.
        remaining_text = text
        while remaining_text:
            earliest_match = None
            earliest_pos = len(remaining_text)
            matched_pattern = None

            # Find the earliest matching pattern
            for pattern, handler in patterns:
                match = re.search(pattern, remaining_text)
                if match and match.start() < earliest_pos:
                    earliest_match = match
                    earliest_pos = match.start()
                    matched_pattern = handler

            if earliest_match:
                # Add text before the match
                if earliest_pos > 0:
                    run = paragraph.add_run(remaining_text[:earliest_pos])
                    ensure_unicode_font(run, remaining_text[:earliest_pos])

                # Apply formatting for the matched text
                if '[' in earliest_match.group(0) and '](' in earliest_match.group(0):
                    # Special handling for links (two groups)
                    matched_pattern(paragraph, earliest_match.group(1), earliest_match.group(2))
                else:
                    matched_pattern(paragraph, earliest_match.group(1))

                # Continue with remaining text
                remaining_text = remaining_text[earliest_match.end():]
            else:
                # No more patterns, add the rest as plain text
                run = paragraph.add_run(remaining_text)
                ensure_unicode_font(run, remaining_text)
                break

    def parse_table(lines, start_idx):
        """Parse a markdown table starting at the given index.

        Returns (table_data, next_index); table_data is None when no table
        starts at start_idx, in which case next_index == start_idx.
        """
        if start_idx >= len(lines):
            return None, start_idx

        # Check if this looks like a table
        if '|' not in lines[start_idx]:
            return None, start_idx

        table_data = []
        idx = start_idx

        while idx < len(lines) and '|' in lines[idx]:
            # Skip separator lines (e.g. |---|:--:|)
            if re.match(r'^[\s\|\-:]+$', lines[idx]):
                idx += 1
                continue

            # Parse cells
            cells = [cell.strip() for cell in lines[idx].split('|')]
            # Remove empty cells at start and end (leading/trailing pipes)
            if cells and not cells[0]:
                cells = cells[1:]
            if cells and not cells[-1]:
                cells = cells[:-1]

            if cells:
                table_data.append(cells)
            idx += 1

        if table_data:
            return table_data, idx
        return None, start_idx

    # Split content into lines
    lines = content.split('\n')
    i = 0
    in_code_block = False
    code_block_content = []

    while i < len(lines):
        line = lines[i]

        # Handle fenced code blocks (``` toggles collection on/off)
        if line.strip().startswith('```'):
            if not in_code_block:
                in_code_block = True
                code_block_content = []
            else:
                # End of code block - add it as preformatted text
                in_code_block = False
                if code_block_content:
                    p = doc.add_paragraph()
                    p.style = 'Normal'
                    code_text = '\n'.join(code_block_content)
                    run = p.add_run(code_text)
                    run.font.name = 'Courier New'
                    run.font.size = Pt(10)
                    run.font.color.rgb = RGBColor(64, 64, 64)
                    # Check if we need Unicode support for code blocks
                    try:
                        code_text.encode('ascii')
                    except UnicodeEncodeError:
                        r = run._element
                        r.rPr.rFonts.set(qn('w:eastAsia'), 'Consolas')
            i += 1
            continue

        if in_code_block:
            code_block_content.append(line)
            i += 1
            continue

        # Check for table
        table_data, end_idx = parse_table(lines, i)
        if table_data:
            # Create Word table sized from the header row's column count
            table = doc.add_table(rows=len(table_data), cols=len(table_data[0]))
            table.style = 'Table Grid'

            # Populate table
            for row_idx, row_data in enumerate(table_data):
                for col_idx, cell_text in enumerate(row_data):
                    if col_idx < len(table.rows[row_idx].cells):
                        cell = table.rows[row_idx].cells[col_idx]
                        # Clear existing paragraphs and add new one
                        cell.text = ""
                        p = cell.add_paragraph()
                        add_formatted_run(p, cell_text)
                        # Make header row bold
                        if row_idx == 0:
                            for run in p.runs:
                                run.bold = True

            doc.add_paragraph('')  # Space after table
            i = end_idx
            continue

        line = line.rstrip()

        # Empty lines become empty paragraphs (preserve vertical spacing)
        if not line:
            doc.add_paragraph('')
            i += 1
            continue

        # Horizontal rule rendered as a centered line of box-drawing dashes
        if re.match(r'^(\*{3,}|-{3,}|_{3,})$', line.strip()):
            p = doc.add_paragraph('─' * 50)
            p.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
            i += 1
            continue

        # Headings
        if line.startswith('# '):
            doc.add_heading(line[2:], 1)
        elif line.startswith('## '):
            doc.add_heading(line[3:], 2)
        elif line.startswith('### '):
            doc.add_heading(line[4:], 3)
        elif line.startswith('#### '):
            doc.add_heading(line[5:], 4)
        # Bullet points
        elif line.lstrip().startswith('- ') or line.lstrip().startswith('* '):
            # Get the indentation level
            indent = len(line) - len(line.lstrip())
            bullet_text = line.lstrip()[2:]
            p = doc.add_paragraph(style='List Bullet')
            # Add indentation if nested
            if indent > 0:
                p.paragraph_format.left_indent = Pt(indent * 10)
            add_formatted_run(p, bullet_text)
        # Numbered lists
        elif re.match(r'^\s*\d+\.', line):
            match = re.match(r'^(\s*)(\d+)\.\s*(.*)', line)
            if match:
                indent = len(match.group(1))
                list_text = match.group(3)
                p = doc.add_paragraph(style='List Number')
                if indent > 0:
                    p.paragraph_format.left_indent = Pt(indent * 10)
                add_formatted_run(p, list_text)
        # Blockquote
        elif line.startswith('> '):
            p = doc.add_paragraph()
            p.paragraph_format.left_indent = Pt(30)
            add_formatted_run(p, line[2:])
            # Add a gray color to indicate quote
            for run in p.runs:
                run.font.color.rgb = RGBColor(100, 100, 100)
        else:
            # Regular paragraph
            p = doc.add_paragraph()
            add_formatted_run(p, line)

        i += 1
|
||||
|
||||
# --- Database Models ---
# Models have been extracted to src/models/ and imported at the top of this file

# --- Forms for Authentication / Custom Password Validator ---
# The password_check utility has been extracted to src/utils/security.py

# --- Blueprint Registration ---
# Blueprints are imported and registered elsewhere for modular route organization
||||
443
src/services/email.py
Normal file
443
src/services/email.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Email service for verification and password reset.
|
||||
|
||||
This module provides email functionality using Python's built-in smtplib.
|
||||
All email features are opt-in via environment variables.
|
||||
"""
|
||||
|
||||
import os
|
||||
import smtplib
|
||||
import logging
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
|
||||
from flask import current_app, url_for
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Token expiry times
|
||||
EMAIL_VERIFICATION_EXPIRY = 24 * 60 * 60 # 24 hours in seconds
|
||||
PASSWORD_RESET_EXPIRY = 1 * 60 * 60 # 1 hour in seconds
|
||||
|
||||
|
||||
def get_email_config():
    """Build the email configuration dict from environment variables.

    Returns:
        dict with keys: enabled, required, smtp_host, smtp_port,
        smtp_username, smtp_password, smtp_use_tls, smtp_use_ssl,
        from_address, from_name.
    """
    def _flag(name, default='false'):
        # Env flags are string booleans: only the literal 'true' (any case) enables.
        return os.environ.get(name, default).lower() == 'true'

    try:
        smtp_port = int(os.environ.get('SMTP_PORT', '587'))
    except ValueError:
        # A malformed SMTP_PORT must not crash callers — this function is
        # evaluated even when email is disabled. Fall back to submission port.
        logger.warning("Invalid SMTP_PORT value; falling back to 587")
        smtp_port = 587

    return {
        'enabled': _flag('ENABLE_EMAIL_VERIFICATION'),
        'required': _flag('REQUIRE_EMAIL_VERIFICATION'),
        'smtp_host': os.environ.get('SMTP_HOST', ''),
        'smtp_port': smtp_port,
        'smtp_username': os.environ.get('SMTP_USERNAME', ''),
        'smtp_password': os.environ.get('SMTP_PASSWORD', ''),
        'smtp_use_tls': _flag('SMTP_USE_TLS', 'true'),
        'smtp_use_ssl': _flag('SMTP_USE_SSL'),
        'from_address': os.environ.get('SMTP_FROM_ADDRESS', 'noreply@yourdomain.com'),
        'from_name': os.environ.get('SMTP_FROM_NAME', 'Speakr'),
    }


def is_email_verification_enabled() -> bool:
    """Check if email verification is enabled."""
    return get_email_config()['enabled']


def is_email_verification_required() -> bool:
    """Check if email verification is required for login (implies enabled)."""
    config = get_email_config()
    return bool(config['enabled'] and config['required'])


def is_smtp_configured() -> bool:
    """Check if SMTP settings are properly configured (host + credentials)."""
    config = get_email_config()
    return bool(config['smtp_host'] and config['smtp_username'] and config['smtp_password'])
|
||||
|
||||
|
||||
def get_serializer(salt: str) -> URLSafeTimedSerializer:
    """Build a URL-safe timed serializer for token generation.

    The salt namespaces tokens so a verification token cannot be replayed
    as a reset token and vice versa.
    """
    key = current_app.config.get('SECRET_KEY', 'default-dev-key')
    return URLSafeTimedSerializer(key, salt=salt)


def generate_verification_token(user_id: int) -> str:
    """Generate an email verification token embedding the user id."""
    return get_serializer('email-verification').dumps(user_id)


def generate_password_reset_token(user_id: int) -> str:
    """Generate a password reset token embedding the user id."""
    return get_serializer('password-reset').dumps(user_id)
|
||||
|
||||
|
||||
def verify_email_token(token: str) -> Optional[int]:
    """
    Verify an email verification token.

    Returns the user_id if valid, None otherwise (expired or tampered).
    """
    serializer = get_serializer('email-verification')
    try:
        return serializer.loads(token, max_age=EMAIL_VERIFICATION_EXPIRY)
    except SignatureExpired:
        logger.warning("Email verification token expired")
    except BadSignature:
        logger.warning("Invalid email verification token")
    return None
|
||||
|
||||
|
||||
def verify_reset_token(token: str) -> Optional[int]:
    """
    Verify a password reset token.

    Returns the user_id if valid, None otherwise (expired or tampered).
    """
    serializer = get_serializer('password-reset')
    try:
        return serializer.loads(token, max_age=PASSWORD_RESET_EXPIRY)
    except SignatureExpired:
        logger.warning("Password reset token expired")
    except BadSignature:
        logger.warning("Invalid password reset token")
    return None
|
||||
|
||||
|
||||
def _send_email(to_email: str, subject: str, html_body: str, text_body: str = None) -> bool:
    """
    Send an email using SMTP.

    Builds a multipart/alternative message (plain-text part optional, HTML
    part last so capable clients prefer it) and delivers it through the
    configured SMTP server.

    Returns True if successful, False otherwise.
    """
    config = get_email_config()

    if not is_smtp_configured():
        logger.error("SMTP is not configured. Cannot send email.")
        return False

    try:
        msg = MIMEMultipart('alternative')
        msg['Subject'] = subject
        msg['From'] = f"{config['from_name']} <{config['from_address']}>"
        msg['To'] = to_email

        # Add plain text version
        if text_body:
            msg.attach(MIMEText(text_body, 'plain'))

        # Add HTML version
        msg.attach(MIMEText(html_body, 'html'))

        # Use the SMTP context manager so the connection is always closed
        # (QUIT) even when starttls/login/sendmail raises — the previous
        # version leaked the socket on every error path.
        smtp_factory = smtplib.SMTP_SSL if config['smtp_use_ssl'] else smtplib.SMTP
        with smtp_factory(config['smtp_host'], config['smtp_port']) as server:
            if not config['smtp_use_ssl'] and config['smtp_use_tls']:
                server.starttls()
            server.login(config['smtp_username'], config['smtp_password'])
            server.sendmail(config['from_address'], to_email, msg.as_string())

        logger.info(f"Email sent successfully to {to_email}")
        return True

    except smtplib.SMTPAuthenticationError as e:
        logger.error(f"SMTP authentication failed: {e}")
        return False
    except smtplib.SMTPException as e:
        logger.error(f"SMTP error sending email: {e}")
        return False
    except Exception as e:
        logger.error(f"Error sending email: {e}")
        return False
|
||||
|
||||
|
||||
def _get_email_template(content_html: str, content_text: str, subject: str) -> tuple[str, str]:
    """
    Wrap content in the Speakr email template.

    Args:
        content_html: Pre-rendered HTML fragment placed in the body cell.
        content_text: Plain-text equivalent used for the text alternative.
        subject: Subject line, reused as the heading of the text version.

    Returns (html_body, text_body)
    """
    # Get the base URL for the logo
    try:
        logo_url = url_for('static', filename='img/icon-192x192.png', _external=True)
    except RuntimeError:
        # Outside of request context, use a placeholder
        logo_url = ""

    # Table-based layout with inline styles: email clients do not reliably
    # support external CSS or modern layout, so everything is inlined.
    html_body = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; line-height: 1.6; color: #1f2937; margin: 0; padding: 0; background-color: #e8eaed;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="100%" style="background-color: #e8eaed;">
<tr>
<td style="padding: 40px 20px;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="600" style="max-width: 600px; margin: 0 auto;">
<!-- Header -->
<tr>
<td style="background-color: #2563eb; padding: 32px 40px; border-radius: 12px 12px 0 0;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="100%">
<tr>
<td>
<!-- Logo and Brand -->
<table role="presentation" cellspacing="0" cellpadding="0" border="0">
<tr>
<td style="vertical-align: middle; padding-right: 12px;">
<img src="{logo_url}" alt="Speakr" width="44" height="44" style="display: block; border-radius: 8px;">
</td>
<td style="vertical-align: middle;">
<h1 style="color: #ffffff; margin: 0; font-size: 28px; font-weight: 700; letter-spacing: -0.5px;">Speakr</h1>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td style="padding-top: 8px;">
<p style="color: rgba(255,255,255,0.85); margin: 0; font-size: 14px;">AI-Powered Audio Transcription</p>
</td>
</tr>
</table>
</td>
</tr>

<!-- Content -->
<tr>
<td style="background-color: #ffffff; padding: 40px; border-left: 1px solid #e5e7eb; border-right: 1px solid #e5e7eb;">
{content_html}
</td>
</tr>

<!-- Footer -->
<tr>
<td style="background-color: #f8f9fa; padding: 24px 40px; border-radius: 0 0 12px 12px; border: 1px solid #e5e7eb; border-top: none;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="100%">
<tr>
<td style="text-align: center;">
<p style="color: #6b7280; font-size: 12px; margin: 0 0 8px 0;">
This email was sent by Speakr. If you have questions, please contact your administrator.
</p>
<p style="color: #9ca3af; font-size: 11px; margin: 0;">
© {datetime.utcnow().year} Speakr · AI-Powered Audio Transcription
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>
"""

    # Plain-text alternative for clients that do not render HTML.
    text_body = f"""
{subject}
{'=' * len(subject)}

{content_text}

---
This email was sent by Speakr - AI-Powered Audio Transcription.
If you have questions, please contact your administrator.
"""

    return html_body, text_body
|
||||
|
||||
|
||||
def send_verification_email(user) -> bool:
    """
    Send a verification email to a user.

    Generates a fresh verification token, persists it on the user row
    (committing the session), builds an absolute verification URL, and
    sends an HTML+plain-text email via the module's SMTP helpers.

    Args:
        user: User model instance (reads .id, .username, .email; writes
            .email_verification_token and .email_verification_sent_at)

    Returns True if email was sent successfully, False otherwise.
    Returns False early — without sending — when email verification is
    disabled or SMTP is not configured.
    """
    from src.database import db

    if not is_email_verification_enabled():
        logger.debug("Email verification is disabled")
        return False

    if not is_smtp_configured():
        logger.warning("Cannot send verification email: SMTP not configured")
        return False

    # Generate token and store it. Note the commit happens BEFORE the SMTP
    # send, so the token stays valid even if sending fails below.
    token = generate_verification_token(user.id)
    user.email_verification_token = token
    user.email_verification_sent_at = datetime.utcnow()
    db.session.commit()

    # Build verification URL (absolute, usable outside a browser session)
    verify_url = url_for('auth.verify_email', token=token, _external=True)

    subject = "Verify your email address - Speakr"

    content_html = f"""
<h2 style="color: #1f2937; margin: 0 0 24px 0; font-size: 24px; font-weight: 600;">Verify Your Email Address</h2>

<p style="color: #374151; margin: 0 0 16px 0; font-size: 16px;">Hi {user.username},</p>

<p style="color: #374151; margin: 0 0 24px 0; font-size: 16px;">
Welcome to Speakr! To complete your registration and start transcribing your audio recordings, please verify your email address.
</p>

<div style="text-align: center; margin: 32px 0;">
<a href="{verify_url}" style="display: inline-block; background-color: #2563eb; color: #ffffff; text-decoration: none; padding: 14px 32px; border-radius: 8px; font-weight: 600; font-size: 16px;">Verify Email Address</a>
</div>

<p style="color: #6b7280; font-size: 14px; margin: 24px 0 8px 0;">Or copy and paste this link into your browser:</p>
<p style="word-break: break-all; color: #2563eb; font-size: 14px; margin: 0; padding: 12px; background-color: #f3f4f6; border-radius: 6px;">{verify_url}</p>

<div style="margin-top: 32px; padding-top: 24px; border-top: 1px solid #e5e7eb;">
<p style="color: #9ca3af; font-size: 13px; margin: 0;">
<strong>This link will expire in 24 hours.</strong><br>
If you didn't create an account on Speakr, you can safely ignore this email.
</p>
</div>
"""

    # Plain-text alternative for clients that do not render HTML.
    content_text = f"""Hi {user.username},

Welcome to Speakr! To complete your registration and start transcribing your audio recordings, please verify your email address.

Click here to verify: {verify_url}

This link will expire in 24 hours.

If you didn't create an account on Speakr, you can safely ignore this email."""

    html_body, text_body = _get_email_template(content_html, content_text, subject)
    return _send_email(user.email, subject, html_body, text_body)
|
||||
|
||||
|
||||
def send_password_reset_email(user) -> bool:
    """
    Send a password reset email to a user.

    Generates a fresh reset token, persists it on the user row (committing
    the session), builds an absolute reset URL, and sends an HTML+plain-text
    email via the module's SMTP helpers.

    Args:
        user: User model instance (reads .id, .username, .email; writes
            .password_reset_token and .password_reset_sent_at)

    Returns True if email was sent successfully, False otherwise.
    Returns False early — without sending — when SMTP is not configured.
    Unlike verification emails, this path is NOT gated behind
    is_email_verification_enabled().
    """
    from src.database import db

    if not is_smtp_configured():
        logger.warning("Cannot send password reset email: SMTP not configured")
        return False

    # Generate token and store it. The commit happens BEFORE the SMTP send,
    # so the token stays valid even if sending fails below.
    token = generate_password_reset_token(user.id)
    user.password_reset_token = token
    user.password_reset_sent_at = datetime.utcnow()
    db.session.commit()

    # Build reset URL (absolute, usable outside a browser session)
    reset_url = url_for('auth.reset_password', token=token, _external=True)

    subject = "Reset your password - Speakr"

    content_html = f"""
<h2 style="color: #1f2937; margin: 0 0 24px 0; font-size: 24px; font-weight: 600;">Reset Your Password</h2>

<p style="color: #374151; margin: 0 0 16px 0; font-size: 16px;">Hi {user.username},</p>

<p style="color: #374151; margin: 0 0 24px 0; font-size: 16px;">
We received a request to reset your Speakr account password. Click the button below to create a new password.
</p>

<div style="text-align: center; margin: 32px 0;">
<a href="{reset_url}" style="display: inline-block; background-color: #2563eb; color: #ffffff; text-decoration: none; padding: 14px 32px; border-radius: 8px; font-weight: 600; font-size: 16px;">Reset Password</a>
</div>

<p style="color: #6b7280; font-size: 14px; margin: 24px 0 8px 0;">Or copy and paste this link into your browser:</p>
<p style="word-break: break-all; color: #2563eb; font-size: 14px; margin: 0; padding: 12px; background-color: #f3f4f6; border-radius: 6px;">{reset_url}</p>

<div style="margin-top: 32px; padding-top: 24px; border-top: 1px solid #e5e7eb;">
<table role="presentation" cellspacing="0" cellpadding="0" border="0" width="100%">
<tr>
<td style="width: 24px; vertical-align: top; padding-right: 12px;">
<span style="font-size: 18px;">⚠️</span>
</td>
<td>
<p style="color: #9ca3af; font-size: 13px; margin: 0;">
<strong style="color: #6b7280;">This link will expire in 1 hour.</strong><br>
If you didn't request a password reset, you can safely ignore this email. Your password will remain unchanged.
</p>
</td>
</tr>
</table>
</div>
"""

    # Plain-text alternative for clients that do not render HTML.
    content_text = f"""Hi {user.username},

We received a request to reset your Speakr account password. Click the link below to create a new password:

{reset_url}

This link will expire in 1 hour.

If you didn't request a password reset, you can safely ignore this email. Your password will remain unchanged."""

    html_body, text_body = _get_email_template(content_html, content_text, subject)
    return _send_email(user.email, subject, html_body, text_body)
|
||||
|
||||
|
||||
def can_resend_verification(user) -> tuple[bool, Optional[int]]:
    """
    Check if a verification email can be resent.

    A fixed 60-second cooldown applies, measured from
    ``user.email_verification_sent_at`` (naive UTC datetime, matching how
    send_verification_email stores it).

    Args:
        user: User model instance (only .email_verification_sent_at is read)

    Returns:
        (can_resend, seconds_until_can_resend) — the second element is None
        when resending is allowed, otherwise the whole number of seconds
        remaining, rounded UP so it is never 0 while still blocked.
        (The previous implementation used timedelta.seconds, which truncated
        fractional seconds and could report "0 seconds remaining" while the
        resend was still refused.)
    """
    import math

    if not user.email_verification_sent_at:
        return True, None

    # Allow resend after 60 seconds
    cooldown = timedelta(seconds=60)
    time_since_last = datetime.utcnow() - user.email_verification_sent_at

    if time_since_last >= cooldown:
        return True, None

    # Round up so the reported wait is never 0 while still blocked.
    remaining = math.ceil((cooldown - time_since_last).total_seconds())
    return False, remaining
|
||||
|
||||
|
||||
def can_resend_password_reset(user) -> tuple[bool, Optional[int]]:
    """
    Check if a password reset email can be resent.

    A fixed 60-second cooldown applies, measured from
    ``user.password_reset_sent_at`` (naive UTC datetime, matching how
    send_password_reset_email stores it).

    Args:
        user: User model instance (only .password_reset_sent_at is read)

    Returns:
        (can_resend, seconds_until_can_resend) — the second element is None
        when resending is allowed, otherwise the whole number of seconds
        remaining, rounded UP so it is never 0 while still blocked.
        (The previous implementation used timedelta.seconds, which truncated
        fractional seconds and could report "0 seconds remaining" while the
        resend was still refused.)
    """
    import math

    if not user.password_reset_sent_at:
        return True, None

    # Allow resend after 60 seconds
    cooldown = timedelta(seconds=60)
    time_since_last = datetime.utcnow() - user.password_reset_sent_at

    if time_since_last >= cooldown:
        return True, None

    # Round up so the reported wait is never 0 while still blocked.
    remaining = math.ceil((cooldown - time_since_last).total_seconds())
    return False, remaining
|
||||
422
src/services/embeddings.py
Normal file
422
src/services/embeddings.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""
|
||||
Embedding generation and semantic search services.
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from flask import current_app
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
EMBEDDINGS_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDDINGS_AVAILABLE = False
|
||||
cosine_similarity = None
|
||||
|
||||
from src.database import db
|
||||
from src.models import Recording, TranscriptChunk, InternalShare, RecordingTag
|
||||
|
||||
ENABLE_INTERNAL_SHARING = os.environ.get('ENABLE_INTERNAL_SHARING', 'false').lower() == 'true'
|
||||
|
||||
# Initialize embedding model (lazy loading)
|
||||
_embedding_model = None
|
||||
|
||||
|
||||
|
||||
def get_embedding_model():
    """Return the shared SentenceTransformer instance, loading it lazily.

    Returns None when the optional embedding dependencies are missing or
    the model fails to load; a failed load is logged and retried on the
    next call (the cached global stays None).
    """
    global _embedding_model

    if not EMBEDDINGS_AVAILABLE:
        return None

    # Fast path: already loaded.
    if _embedding_model is not None:
        return _embedding_model

    try:
        _embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        current_app.logger.info("Embedding model loaded successfully")
    except Exception as e:
        current_app.logger.error(f"Failed to load embedding model: {e}")
        return None
    return _embedding_model
|
||||
|
||||
|
||||
|
||||
def chunk_transcription(transcription, max_chunk_length=500, overlap=50):
    """
    Split a transcription into overlapping chunks for retrieval.

    Chunks are cut at ``max_chunk_length`` characters, preferring the last
    sentence boundary ('.', '!' or '?' followed by whitespace) found within
    the final 100 characters of the window. Consecutive chunks overlap by
    roughly ``overlap`` characters to preserve context across boundaries.

    Args:
        transcription (str): The full transcription text
        max_chunk_length (int): Maximum characters per chunk
        overlap (int): Character overlap between chunks

    Returns:
        list: Stripped, non-empty text chunks (empty list for empty input;
        a single-element list when the text already fits in one chunk).
    """
    if not transcription:
        return []

    text_len = len(transcription)
    if text_len <= max_chunk_length:
        return [transcription]

    pieces = []
    pos = 0

    while pos < text_len:
        cut = pos + max_chunk_length

        # Prefer a sentence boundary near the end of the window.
        if cut < text_len:
            boundary = -1
            for idx in range(max(0, cut - 100), cut):
                # A boundary is punctuation followed by whitespace (this
                # also skips most abbreviations like "e.g.x").
                if transcription[idx] in '.!?':
                    if idx + 1 < text_len and transcription[idx + 1].isspace():
                        boundary = idx + 1
            if boundary > pos:
                cut = boundary

        piece = transcription[pos:cut].strip()
        if piece:
            pieces.append(piece)

        # Step forward with overlap; max(pos + 1, ...) guarantees progress.
        pos = max(pos + 1, cut - overlap)
        if pos >= text_len:
            break

    return pieces
|
||||
|
||||
|
||||
|
||||
def generate_embeddings(texts):
    """
    Encode a list of texts into float32 embedding vectors.

    Args:
        texts (list): Text strings to embed.

    Returns:
        list: One numpy float32 vector per input text. Empty list when the
        embedding stack is unavailable, no model could be loaded, *texts*
        is empty, or encoding fails (failures are logged, never raised).
    """
    if not EMBEDDINGS_AVAILABLE:
        current_app.logger.warning("Embeddings not available - skipping embedding generation")
        return []

    model = get_embedding_model()
    if not model or not texts:
        return []

    try:
        # Normalize dtype so serialization/deserialization stays consistent.
        return [vec.astype(np.float32) for vec in model.encode(texts)]
    except Exception as e:
        current_app.logger.error(f"Error generating embeddings: {e}")
        return []
|
||||
|
||||
|
||||
|
||||
def serialize_embedding(embedding):
    """Pack a numpy embedding vector into raw bytes for BLOB storage.

    Returns None when the vector is None or the embedding stack is absent.
    """
    if not EMBEDDINGS_AVAILABLE or embedding is None:
        return None
    return embedding.tobytes()
|
||||
|
||||
|
||||
|
||||
def deserialize_embedding(binary_data):
    """Rebuild a float32 numpy vector from bytes written by serialize_embedding.

    Returns None when the data is None or the embedding stack is absent.
    """
    if not EMBEDDINGS_AVAILABLE or binary_data is None:
        return None
    return np.frombuffer(binary_data, dtype=np.float32)
|
||||
|
||||
|
||||
|
||||
def get_accessible_recording_ids(user_id):
    """
    Get all recording IDs that a user has access to.

    Includes:
    - Recordings owned by the user
    - Recordings shared with the user via InternalShare (only when the
      ENABLE_INTERNAL_SHARING env flag was set at import time)

    NOTE(review): an earlier version of this docstring also promised
    group-tag/team-based access, but no such lookup is implemented below —
    confirm against the sharing feature before relying on it.

    Args:
        user_id (int): User ID to check access for

    Returns:
        list: Deduplicated, unordered list of recording IDs the user can
        access
    """
    accessible_ids = set()

    # 1. User's own recordings
    own_recordings = db.session.query(Recording.id).filter_by(user_id=user_id).all()
    accessible_ids.update([r.id for r in own_recordings])

    # 2. Internally shared recordings (feature-flagged)
    if ENABLE_INTERNAL_SHARING:
        shared_recordings = db.session.query(InternalShare.recording_id).filter_by(
            shared_with_user_id=user_id
        ).all()
        accessible_ids.update([r.recording_id for r in shared_recordings])

    return list(accessible_ids)
|
||||
|
||||
|
||||
|
||||
def process_recording_chunks(recording_id):
    """
    Chunk a recording's transcription and store the chunks for search.

    Should be called after a recording is transcribed. Any existing chunks
    for the recording are replaced. When the embedding stack is unavailable
    (generate_embeddings returns []), chunks are STILL stored with a None
    embedding so that basic_text_search_chunks — which matches on chunk
    content, not embeddings — can find them.

    (Previously `zip(chunks, embeddings)` silently stored zero chunks
    whenever embeddings came back empty, starving the text-search fallback
    while logging that chunks had been created.)

    Args:
        recording_id (int): Primary key of the Recording to process.

    Returns:
        bool: True on success (including "nothing to chunk"); False when the
        recording is missing, has no transcription, or an error occurred
        (errors are logged and the session rolled back).
    """
    try:
        recording = db.session.get(Recording, recording_id)
        if not recording or not recording.transcription:
            return False

        # Delete existing chunks for this recording
        TranscriptChunk.query.filter_by(recording_id=recording_id).delete()

        # Create chunks
        chunks = chunk_transcription(recording.transcription)

        if not chunks:
            return True

        # Generate embeddings; may legitimately be [] (or shorter than
        # chunks) when embeddings are unavailable or encoding failed.
        embeddings = generate_embeddings(chunks)

        # Store every chunk, attaching an embedding when one exists.
        for i, chunk_text in enumerate(chunks):
            embedding = embeddings[i] if i < len(embeddings) else None
            db.session.add(TranscriptChunk(
                recording_id=recording_id,
                user_id=recording.user_id,
                chunk_index=i,
                content=chunk_text,
                embedding=serialize_embedding(embedding) if embedding is not None else None
            ))

        db.session.commit()
        current_app.logger.info(f"Created {len(chunks)} chunks for recording {recording_id}")
        return True

    except Exception as e:
        current_app.logger.error(f"Error processing chunks for recording {recording_id}: {e}")
        db.session.rollback()
        return False
|
||||
|
||||
|
||||
|
||||
def basic_text_search_chunks(user_id, query, filters=None, top_k=5):
    """
    Basic text search fallback when embeddings are not available.

    Uses simple keyword matching instead of semantic search: candidates are
    chunks containing ANY non-stop-word query term, ranked by the fraction
    of query terms they contain. Searches across the user's own recordings
    and recordings shared with them.

    Args:
        user_id (int): User ID for permission filtering
        query (str): Free-text search query
        filters (dict): Optional keys: tag_ids, speaker_names,
            recording_ids, date_from, date_to
        top_k (int): Maximum number of results to return

    Returns:
        list: (TranscriptChunk, score) tuples, score in [0.0, 1.0], sorted
        by score descending. Empty list on error or when no usable query
        words / candidates exist.

    NOTE(review): several filter branches each call .join(Recording) on the
    same query (e.g. tag_ids combined with date_from/date_to would join
    Recording twice) — confirm SQLAlchemy tolerates this for the filter
    combinations actually used.
    """
    try:
        # Get all accessible recording IDs (own + shared)
        accessible_recording_ids = get_accessible_recording_ids(user_id)

        if not accessible_recording_ids:
            return []

        # Build base query for chunks from accessible recordings with eager
        # loading of the parent recording (avoids N+1 when callers read it).
        chunks_query = TranscriptChunk.query.options(joinedload(TranscriptChunk.recording)).filter(
            TranscriptChunk.recording_id.in_(accessible_recording_ids)
        )

        # Apply filters if provided
        if filters:
            if filters.get('tag_ids'):
                chunks_query = chunks_query.join(Recording).join(
                    RecordingTag, Recording.id == RecordingTag.recording_id
                ).filter(RecordingTag.tag_id.in_(filters['tag_ids']))

            if filters.get('speaker_names'):
                # Filter by participants field in recordings instead of chunk speaker_name
                if not any(hasattr(desc, 'name') and desc.name == 'recording' for desc in chunks_query.column_descriptions):
                    chunks_query = chunks_query.join(Recording)

                # Build OR conditions for each speaker name in participants
                speaker_conditions = []
                for speaker_name in filters['speaker_names']:
                    speaker_conditions.append(
                        Recording.participants.ilike(f'%{speaker_name}%')
                    )

                chunks_query = chunks_query.filter(db.or_(*speaker_conditions))
                current_app.logger.info(f"Applied speaker filter for: {filters['speaker_names']}")

            if filters.get('recording_ids'):
                chunks_query = chunks_query.filter(
                    TranscriptChunk.recording_id.in_(filters['recording_ids'])
                )

            if filters.get('date_from') or filters.get('date_to'):
                chunks_query = chunks_query.join(Recording)
                if filters.get('date_from'):
                    chunks_query = chunks_query.filter(Recording.meeting_date >= filters['date_from'])
                if filters.get('date_to'):
                    chunks_query = chunks_query.filter(Recording.meeting_date <= filters['date_to'])

        # Text search - filter stop words and rank by match count
        stop_words = {'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
                      'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
                      'would', 'could', 'should', 'may', 'might', 'shall', 'can',
                      'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by', 'from',
                      'up', 'about', 'into', 'through', 'during', 'before', 'after',
                      'and', 'but', 'or', 'nor', 'not', 'so', 'yet', 'both',
                      'it', 'its', 'this', 'that', 'these', 'those', 'what', 'which',
                      'who', 'whom', 'how', 'when', 'where', 'why',
                      'i', 'me', 'my', 'we', 'our', 'you', 'your', 'he', 'she',
                      'his', 'her', 'they', 'them', 'their'}

        # Single-character tokens are also dropped.
        query_words = [w for w in query.lower().split() if w not in stop_words and len(w) > 1]

        if not query_words:
            # If all words were stop words, fall back to using original query words
            query_words = [w for w in query.lower().split() if len(w) > 1]

        if query_words:
            # NOTE(review): func/case/literal are imported but unused here.
            from sqlalchemy import or_, func, case, literal

            # Filter: match ANY keyword (OR) to get candidates
            text_conditions = []
            for word in query_words:
                text_conditions.append(TranscriptChunk.content.ilike(f'%{word}%'))
            chunks_query = chunks_query.filter(or_(*text_conditions))

            # Fetch more candidates than needed so we can rank them
            chunks = chunks_query.limit(top_k * 5).all()

            # Rank by how many query words each chunk matches
            scored_chunks = []
            for chunk in chunks:
                content_lower = chunk.content.lower()
                match_count = sum(1 for word in query_words if word in content_lower)
                score = match_count / len(query_words)  # 0.0 to 1.0
                scored_chunks.append((chunk, score))

            # Sort by score descending, take top_k
            scored_chunks.sort(key=lambda x: x[1], reverse=True)
            return scored_chunks[:top_k]

        # No usable query words
        return []

    except Exception as e:
        current_app.logger.error(f"Error in basic text search: {e}")
        return []
|
||||
|
||||
|
||||
|
||||
def semantic_search_chunks(user_id, query, filters=None, top_k=5):
    """
    Perform semantic search on transcript chunks with filtering.

    Embeds the query with the shared sentence-transformer model and ranks
    accessible chunks by cosine similarity against their stored embeddings.
    Falls back to basic_text_search_chunks when embeddings are unavailable
    or the model cannot be loaded. Searches across the user's own
    recordings and recordings shared with them.

    Args:
        user_id (int): User ID for permission filtering
        query (str): Search query
        filters (dict): Optional filters for tags, speakers, dates, recording_ids
        top_k (int): Number of top chunks to return

    Returns:
        list: (TranscriptChunk, similarity) tuples sorted by similarity
        descending; empty list on error or no candidates. Chunks with a
        NULL embedding are excluded (they can only be found by the
        text-search fallback).

    NOTE(review): the same duplicate-.join(Recording) hazard as in
    basic_text_search_chunks applies when multiple filters are combined.
    """
    try:
        # If embeddings are not available, fall back to basic text search
        if not EMBEDDINGS_AVAILABLE:
            current_app.logger.info("Embeddings not available - using basic text search as fallback")
            return basic_text_search_chunks(user_id, query, filters, top_k)

        # Generate embedding for the query
        model = get_embedding_model()
        if not model:
            return basic_text_search_chunks(user_id, query, filters, top_k)

        query_embedding = model.encode([query])[0]

        # Get all accessible recording IDs (own + shared)
        accessible_recording_ids = get_accessible_recording_ids(user_id)

        if not accessible_recording_ids:
            return []

        # Build base query for chunks from accessible recordings with eager loading
        chunks_query = TranscriptChunk.query.options(joinedload(TranscriptChunk.recording)).filter(
            TranscriptChunk.recording_id.in_(accessible_recording_ids)
        )

        # Apply filters if provided
        if filters:
            if filters.get('tag_ids'):
                # Join with recordings that have specified tags
                chunks_query = chunks_query.join(Recording).join(
                    RecordingTag, Recording.id == RecordingTag.recording_id
                ).filter(RecordingTag.tag_id.in_(filters['tag_ids']))

            if filters.get('speaker_names'):
                # Filter by participants field in recordings instead of chunk speaker_name
                if not any(hasattr(desc, 'name') and desc.name == 'recording' for desc in chunks_query.column_descriptions):
                    chunks_query = chunks_query.join(Recording)

                # Build OR conditions for each speaker name in participants
                speaker_conditions = []
                for speaker_name in filters['speaker_names']:
                    speaker_conditions.append(
                        Recording.participants.ilike(f'%{speaker_name}%')
                    )

                chunks_query = chunks_query.filter(db.or_(*speaker_conditions))
                current_app.logger.info(f"Applied speaker filter for: {filters['speaker_names']}")

            if filters.get('recording_ids'):
                chunks_query = chunks_query.filter(
                    TranscriptChunk.recording_id.in_(filters['recording_ids'])
                )

            if filters.get('date_from') or filters.get('date_to'):
                chunks_query = chunks_query.join(Recording)
                if filters.get('date_from'):
                    chunks_query = chunks_query.filter(Recording.meeting_date >= filters['date_from'])
                if filters.get('date_to'):
                    chunks_query = chunks_query.filter(Recording.meeting_date <= filters['date_to'])

        # Get chunks that have embeddings
        chunks = chunks_query.filter(TranscriptChunk.embedding.isnot(None)).all()

        if not chunks:
            return []

        # Calculate similarities in Python (all candidate embeddings are
        # loaded and compared one by one — fine at current scale).
        chunk_similarities = []
        for chunk in chunks:
            try:
                chunk_embedding = deserialize_embedding(chunk.embedding)
                if chunk_embedding is not None:
                    similarity = cosine_similarity(
                        query_embedding.reshape(1, -1),
                        chunk_embedding.reshape(1, -1)
                    )[0][0]
                    chunk_similarities.append((chunk, float(similarity)))
            except Exception as e:
                # A single bad embedding must not break the whole search.
                current_app.logger.warning(f"Error calculating similarity for chunk {chunk.id}: {e}")
                continue

        # Sort by similarity and return top k
        chunk_similarities.sort(key=lambda x: x[1], reverse=True)
        return chunk_similarities[:top_k]

    except Exception as e:
        current_app.logger.error(f"Error in semantic search: {e}")
        return []
|
||||
|
||||
# --- Helper Functions for Document Processing ---
|
||||
|
||||
|
||||
|
||||
631
src/services/job_queue.py
Normal file
631
src/services/job_queue.py
Normal file
@@ -0,0 +1,631 @@
|
||||
"""
|
||||
Fair database-backed job queue for background processing tasks.
|
||||
|
||||
This queue ensures:
|
||||
- Jobs persist across application restarts
|
||||
- Fair round-robin scheduling between users
|
||||
- Separate queues for transcription (slow) and summary (fast) jobs
|
||||
- Limited concurrency to prevent overwhelming external services
|
||||
- Automatic recovery of orphaned jobs
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
TRANSCRIPTION_WORKERS = int(os.environ.get('JOB_QUEUE_WORKERS', '2'))
|
||||
SUMMARY_WORKERS = int(os.environ.get('SUMMARY_QUEUE_WORKERS', '2'))
|
||||
MAX_RETRIES = int(os.environ.get('JOB_MAX_RETRIES', '3'))
|
||||
POLL_INTERVAL = 1.0 # seconds between checking for new jobs
|
||||
|
||||
# Job type categories
|
||||
TRANSCRIPTION_JOBS = ['transcribe', 'reprocess_transcription']
|
||||
SUMMARY_JOBS = ['summarize', 'reprocess_summary']
|
||||
|
||||
|
||||
class FairJobQueue:
|
||||
"""
|
||||
A database-backed job queue with fair scheduling across users.
|
||||
|
||||
Uses separate queues for transcription and summary jobs to prevent
|
||||
slow transcription jobs from blocking fast summary jobs.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
    def __new__(cls):
        """Singleton pattern to ensure only one queue exists.

        Double-checked locking: the unlocked check keeps the hot path
        cheap; the re-check under cls._lock guards against two threads
        racing past the first check.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    # __init__ runs on every FairJobQueue() call; this flag
                    # lets it bail out after the first real initialization.
                    cls._instance._initialized = False
        return cls._instance
|
||||
|
||||
    def __init__(self):
        """Initialize the job queue (no-op after the first call — see __new__)."""
        if self._initialized:
            return

        # Worker thread handles, kept per queue so stop() can join them.
        self._transcription_workers = []
        self._summary_workers = []
        self._running = False
        # Flask app used by _app_context(); set via init_app().
        self._app = None
        # Separate round-robin tracking for each queue
        self._last_user_id_transcription = None
        self._last_user_id_summary = None
        # Lock for claiming jobs (SQLite doesn't support row-level locking)
        self._claim_lock = threading.Lock()
        self._initialized = True

        logger.info(f"FairJobQueue initialized: {TRANSCRIPTION_WORKERS} transcription workers, {SUMMARY_WORKERS} summary workers")
|
||||
|
||||
    def init_app(self, app):
        """Initialize with Flask app for context management.

        Should be called before start(): worker threads rely on
        _app_context() to push an application context, which is a no-op
        when no app has been registered.
        """
        self._app = app
|
||||
|
||||
    @contextmanager
    def _app_context(self):
        """Get application context for database operations.

        Falls back to a bare yield when init_app() was never called, in
        which case the caller must already be inside an app context for DB
        access to work.
        """
        if self._app:
            with self._app.app_context():
                yield
        else:
            yield
|
||||
|
||||
def start(self):
|
||||
"""Start the worker threads for both queues."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Start transcription workers
|
||||
for i in range(TRANSCRIPTION_WORKERS):
|
||||
worker = threading.Thread(
|
||||
target=self._worker_loop,
|
||||
args=(TRANSCRIPTION_JOBS, 'transcription'),
|
||||
name=f"TranscriptionWorker-{i}",
|
||||
daemon=True
|
||||
)
|
||||
worker.start()
|
||||
self._transcription_workers.append(worker)
|
||||
|
||||
# Start summary workers
|
||||
for i in range(SUMMARY_WORKERS):
|
||||
worker = threading.Thread(
|
||||
target=self._worker_loop,
|
||||
args=(SUMMARY_JOBS, 'summary'),
|
||||
name=f"SummaryWorker-{i}",
|
||||
daemon=True
|
||||
)
|
||||
worker.start()
|
||||
self._summary_workers.append(worker)
|
||||
|
||||
logger.info(f"Started {TRANSCRIPTION_WORKERS} transcription workers and {SUMMARY_WORKERS} summary workers")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the worker threads gracefully."""
|
||||
self._running = False
|
||||
for worker in self._transcription_workers + self._summary_workers:
|
||||
worker.join(timeout=5)
|
||||
self._transcription_workers.clear()
|
||||
self._summary_workers.clear()
|
||||
logger.info("Job queue workers stopped")
|
||||
|
||||
def _worker_loop(self, job_types: List[str], queue_name: str):
|
||||
"""Main worker loop that processes jobs of specific types."""
|
||||
while self._running:
|
||||
try:
|
||||
job = self._claim_next_job(job_types, queue_name)
|
||||
if job:
|
||||
self._process_job(job)
|
||||
else:
|
||||
# No jobs available, sleep briefly
|
||||
time.sleep(POLL_INTERVAL)
|
||||
except Exception as e:
|
||||
logger.error(f"{queue_name.capitalize()} worker error: {e}", exc_info=True)
|
||||
time.sleep(POLL_INTERVAL)
|
||||
|
||||
    def _claim_next_job(self, job_types: List[str], queue_name: str):
        """
        Claim the next job of specified types using fair round-robin scheduling.

        Fairness: users with queued jobs are ordered by their oldest queued
        job, and the user AFTER the last-served one (tracked per queue) is
        picked next — so one user's backlog cannot starve the others. The
        claim itself is a conditional UPDATE (status must still be
        'queued'), making it safe even if the in-process lock were bypassed.

        Args:
            job_types: List of job types this worker handles
            queue_name: Name of the queue ('transcription' or 'summary')

        Returns the claimed job or None if no jobs available. None is also
        returned when another worker won the claim race — the caller simply
        polls again.
        """
        # Use lock to prevent race conditions (SQLite doesn't support row-level locking)
        with self._claim_lock:
            with self._app_context():
                from src.database import db
                from src.models import ProcessingJob

                try:
                    # Get list of users with queued jobs of our types,
                    # ordered by each user's oldest queued job.
                    users_with_jobs = db.session.query(
                        ProcessingJob.user_id
                    ).filter(
                        ProcessingJob.status == 'queued',
                        ProcessingJob.job_type.in_(job_types)
                    ).group_by(
                        ProcessingJob.user_id
                    ).order_by(
                        db.func.min(ProcessingJob.created_at)
                    ).all()

                    if not users_with_jobs:
                        return None

                    user_ids = [u[0] for u in users_with_jobs]

                    # Get last user ID for this queue type
                    last_user_id = (self._last_user_id_transcription
                                    if queue_name == 'transcription'
                                    else self._last_user_id_summary)

                    # Round-robin: pick next user after last processed
                    next_user_id = None
                    if last_user_id is not None and last_user_id in user_ids:
                        idx = user_ids.index(last_user_id)
                        next_user_id = user_ids[(idx + 1) % len(user_ids)]
                    else:
                        next_user_id = user_ids[0]

                    # Get oldest queued job of our types for this user
                    candidate_job = ProcessingJob.query.filter(
                        ProcessingJob.user_id == next_user_id,
                        ProcessingJob.status == 'queued',
                        ProcessingJob.job_type.in_(job_types)
                    ).order_by(
                        ProcessingJob.created_at
                    ).first()

                    if candidate_job:
                        # Atomically claim the job - only succeeds if status is still 'queued'
                        # This prevents race conditions when multiple workers try to claim the same job
                        from sqlalchemy import update
                        claim_time = datetime.utcnow()
                        result = db.session.execute(
                            update(ProcessingJob)
                            .where(
                                ProcessingJob.id == candidate_job.id,
                                ProcessingJob.status == 'queued'  # Critical: only claim if still queued
                            )
                            .values(status='processing', started_at=claim_time)
                        )

                        if result.rowcount == 0:
                            # Job was already claimed by another worker - this is expected with multiple workers
                            logger.debug(f"[{queue_name.upper()}] Job {candidate_job.id} already claimed by another worker")
                            db.session.rollback()
                            return None

                        # Also update Recording.status to reflect active processing
                        from src.models import Recording
                        recording = db.session.get(Recording, candidate_job.recording_id)
                        if recording and recording.status == 'QUEUED':
                            recording.status = 'PROCESSING'

                        db.session.commit()

                        # Refresh the job object to get updated values
                        db.session.refresh(candidate_job)

                        # Update last user ID for this queue (round-robin cursor)
                        if queue_name == 'transcription':
                            self._last_user_id_transcription = next_user_id
                        else:
                            self._last_user_id_summary = next_user_id

                        # wait_time = queue latency, useful for capacity monitoring
                        wait_time = (claim_time - candidate_job.created_at).total_seconds()
                        logger.info(f"[{queue_name.upper()}] Claimed job {candidate_job.id} (type={candidate_job.job_type}) for user {candidate_job.user_id}, recording {candidate_job.recording_id} (waited {wait_time:.1f}s)")
                        return candidate_job

                    return None

                except Exception as e:
                    logger.error(f"Error claiming {queue_name} job: {e}", exc_info=True)
                    db.session.rollback()
                    return None
|
||||
|
||||
def _is_permanent_error(self, error_str: str) -> bool:
    """
    Return True when *error_str* describes a non-retryable failure.

    Permanent failures include client-side HTTP errors (400/401/402/403/
    404/413) and provider messages about credentials, billing, quotas,
    missing models, or unsupported file formats. Matching is simple
    case-insensitive substring search against the raw error text.
    """
    text = error_str.lower()

    # HTTP status codes that will not succeed on retry.
    for status in ('400', '413', '401', '402', '403', '404'):
        if f'error code: {status}' in text or f'status {status}' in text:
            return True

    # Provider-specific messages that indicate a permanent problem.
    fatal_markers = (
        'maximum content size limit',
        'file too large',
        'payload too large',
        'invalid api key',
        'incorrect api key',
        'authentication failed',
        'unauthorized',
        'permission denied',
        'access denied',
        'billing',
        'payment required',
        'quota exceeded',
        'insufficient funds',
        'model not found',
        'invalid model',
        'unsupported format',
        'invalid file format',
        'invalid_request_error',
        'bad request',
    )
    return any(marker in text for marker in fatal_markers)
|
||||
|
||||
def _process_job(self, job):
    """Process a single job by dispatching to the appropriate task function.

    On success the job is marked 'completed'. On failure it is either
    re-queued (transient error, still under MAX_RETRIES) or marked
    'failed', in which case the recording is set to FAILED so the user
    can see the error and reprocess later.
    """
    # Capture plain attributes up front: the ORM object may be detached
    # once we enter the fresh app context below, so we re-fetch by id.
    job_id = job.id
    job_type = job.job_type
    recording_id = job.recording_id
    params_str = job.params
    is_new_upload = job.is_new_upload

    with self._app_context():
        from src.database import db
        from src.models import ProcessingJob, Recording
        from flask import current_app

        try:
            # Parse job parameters
            params = json.loads(params_str) if params_str else {}

            # Re-fetch the job in this session context to ensure it's attached
            job = db.session.get(ProcessingJob, job_id)
            if not job:
                logger.error(f"Job {job_id} not found when trying to process")
                return

            # Get recording
            recording = db.session.get(Recording, recording_id)
            if not recording:
                raise ValueError(f"Recording {recording_id} not found")

            # Dispatch based on job type
            if job_type == 'transcribe':
                self._run_transcription(job, recording, params)
            elif job_type == 'summarize':
                self._run_summarization(job, recording, params)
            elif job_type == 'reprocess_transcription':
                self._run_reprocess_transcription(job, recording, params)
            elif job_type == 'reprocess_summary':
                self._run_reprocess_summary(job, recording, params)
            else:
                raise ValueError(f"Unknown job type: {job_type}")

            # Mark as completed - re-fetch to ensure we have latest state
            # (the task function may have committed/expired the session).
            job = db.session.get(ProcessingJob, job_id)
            if job:
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                db.session.commit()
                logger.info(f"Job {job_id} completed successfully")

        except Exception as e:
            error_str = str(e)
            logger.error(f"Job {job_id} failed: {e}", exc_info=True)

            # Check if this is a permanent error that shouldn't be retried
            is_permanent_error = self._is_permanent_error(error_str)

            # Re-fetch job to update it
            job = db.session.get(ProcessingJob, job_id)
            if job:
                job.error_message = error_str
                job.retry_count += 1

                # Only retry if: not a permanent error AND under retry limit
                if not is_permanent_error and job.retry_count < MAX_RETRIES:
                    # Re-queue for retry
                    job.status = 'queued'
                    job.started_at = None
                    logger.info(f"Job {job_id} re-queued for retry ({job.retry_count}/{MAX_RETRIES})")
                else:
                    job.status = 'failed'
                    job.completed_at = datetime.utcnow()
                    recording = db.session.get(Recording, recording_id)

                    if is_permanent_error:
                        logger.info(f"Job {job_id} failed with permanent error (no retry): {error_str[:100]}")

                    # Always keep recordings with FAILED status so users can see the error
                    # and reprocess later (e.g., when ASR server recovers)
                    if recording:
                        # Keep the recording with FAILED status so user can see the error and fix settings
                        recording.status = 'FAILED'
                        # Format the error for nice display
                        from src.utils.error_formatting import format_error_for_storage
                        recording.transcription = format_error_for_storage(error_str)

                    if is_permanent_error:
                        logger.error(f"Job {job_id} failed permanently (non-retryable error)")
                    else:
                        logger.error(f"Job {job_id} failed permanently after {MAX_RETRIES} retries")

            db.session.commit()
|
||||
|
||||
def _run_transcription(self, job, recording, params):
    """Run transcription task. Status updates handled by task function."""
    from src.tasks.processing import transcribe_audio_task
    from flask import current_app

    audio_file = recording.audio_path
    asr_filename = recording.original_filename or os.path.basename(audio_file)

    # Collect the optional ASR tuning parameters from the job payload;
    # absent keys default to None, matching the task's signature.
    asr_options = {
        key: params.get(key)
        for key in ('language', 'min_speakers', 'max_speakers',
                    'tag_id', 'hotwords', 'initial_prompt')
    }

    transcribe_audio_task(
        current_app._get_current_object().app_context(),
        recording.id,
        audio_file,
        asr_filename,
        datetime.utcnow(),
        **asr_options,
    )
|
||||
|
||||
def _run_summarization(self, job, recording, params):
    """Run summarization-only task. Status updates handled by task function."""
    from src.tasks.processing import generate_summary_only_task
    from flask import current_app

    # Hand the task a real app context object so it can push/pop it itself.
    app_ctx = current_app._get_current_object().app_context()
    generate_summary_only_task(
        app_ctx,
        recording.id,
        custom_prompt_override=params.get('custom_prompt'),
        user_id=params.get('user_id'),
    )
|
||||
|
||||
def _run_reprocess_transcription(self, job, recording, params):
    """Run transcription reprocessing task. Status updates handled by task function.

    Reprocessing goes through the exact same pipeline as a fresh
    transcription, so delegate to _run_transcription rather than
    duplicating its body (the original was a byte-for-byte copy).
    """
    self._run_transcription(job, recording, params)
|
||||
|
||||
def _run_reprocess_summary(self, job, recording, params):
    """Run summary reprocessing task. Status updates handled by task function.

    Summary reprocessing is identical to a first-time summarization, so
    delegate to _run_summarization rather than duplicating its body
    (the original was a byte-for-byte copy).
    """
    self._run_summarization(job, recording, params)
|
||||
|
||||
def enqueue(
    self,
    user_id: int,
    recording_id: int,
    job_type: str,
    params: Dict[str, Any] = None,
    is_new_upload: bool = False
) -> int:
    """
    Add a job to the database queue.

    Args:
        user_id: ID of the user who owns this job
        recording_id: ID of the recording to process
        job_type: Type of job (transcribe, summarize, reprocess_transcription, reprocess_summary)
        params: Optional parameters for the job
        is_new_upload: True if this is a new file upload (for cleanup on failure)

    Returns:
        The created job ID (or the existing job's ID if a duplicate is active)
    """
    with self._app_context():
        from src.database import db
        from src.models import ProcessingJob, Recording

        # Deduplicate: a recording may have at most one active job per
        # job type, while jobs of different types can coexist
        # (e.g. transcribe and summarize).
        active_duplicate = ProcessingJob.query.filter(
            ProcessingJob.recording_id == recording_id,
            ProcessingJob.job_type == job_type,
            ProcessingJob.status.in_(['queued', 'processing'])
        ).first()
        if active_duplicate:
            logger.warning(f"Job of type {job_type} already exists for recording {recording_id}: {active_duplicate.id}")
            return active_duplicate.id

        new_job = ProcessingJob(
            user_id=user_id,
            recording_id=recording_id,
            job_type=job_type,
            params=json.dumps(params) if params else None,
            is_new_upload=is_new_upload
        )
        db.session.add(new_job)

        # Reflect the pending work on the recording itself.
        recording = db.session.get(Recording, recording_id)
        if recording:
            recording.status = 'SUMMARIZING' if job_type in SUMMARY_JOBS else 'QUEUED'

        db.session.commit()

        # Lazily start the worker threads on first use.
        if not self._running:
            self.start()

        queue_name = 'summary' if job_type in SUMMARY_JOBS else 'transcription'
        logger.info(f"Enqueued {queue_name} job {new_job.id} (type={job_type}) for user {user_id}, recording {recording_id}")
        return new_job.id
|
||||
|
||||
def recover_orphaned_jobs(self):
    """
    Recover jobs that were processing when the app crashed.
    Call this on startup to reset orphaned jobs back to queued.
    """
    with self._app_context():
        from src.database import db
        from src.models import ProcessingJob

        stuck_jobs = ProcessingJob.query.filter(
            ProcessingJob.status == 'processing'
        ).all()

        for stuck in stuck_jobs:
            # Put the job back in line as if it had never been claimed.
            stuck.status = 'queued'
            stuck.started_at = None
            queue_name = 'summary' if stuck.job_type in SUMMARY_JOBS else 'transcription'
            logger.info(f"Recovered orphaned {queue_name} job {stuck.id} for recording {stuck.recording_id}")

        if stuck_jobs:
            db.session.commit()
            logger.info(f"Recovered {len(stuck_jobs)} orphaned jobs")
|
||||
|
||||
def get_queue_status(self) -> Dict[str, Any]:
    """Get the current queue status for both queues."""
    with self._app_context():
        from src.models import ProcessingJob

        def count_jobs(status, job_types):
            # One COUNT(*) per (status, queue) combination.
            return ProcessingJob.query.filter(
                ProcessingJob.status == status,
                ProcessingJob.job_type.in_(job_types)
            ).count()

        return {
            "transcription_queue": {
                "queued": count_jobs('queued', TRANSCRIPTION_JOBS),
                "processing": count_jobs('processing', TRANSCRIPTION_JOBS),
                "workers": TRANSCRIPTION_WORKERS
            },
            "summary_queue": {
                "queued": count_jobs('queued', SUMMARY_JOBS),
                "processing": count_jobs('processing', SUMMARY_JOBS),
                "workers": SUMMARY_WORKERS
            },
            "is_running": self._running
        }
|
||||
|
||||
def get_position_in_queue(self, recording_id: int) -> Optional[int]:
    """Get the position of a recording's job in its respective queue (1-indexed)."""
    with self._app_context():
        from src.models import ProcessingJob

        queued_job = ProcessingJob.query.filter(
            ProcessingJob.recording_id == recording_id,
            ProcessingJob.status == 'queued'
        ).first()
        if queued_job is None:
            return None

        # The job waits in whichever queue its type belongs to.
        same_queue = SUMMARY_JOBS if queued_job.job_type in SUMMARY_JOBS else TRANSCRIPTION_JOBS

        # Position = (# of queued same-queue jobs created earlier) + 1.
        ahead = ProcessingJob.query.filter(
            ProcessingJob.status == 'queued',
            ProcessingJob.job_type.in_(same_queue),
            ProcessingJob.created_at < queued_job.created_at
        ).count()
        return ahead + 1
|
||||
|
||||
def get_job_for_recording(self, recording_id: int):
    """Get the active (queued or processing) job for a recording, or None."""
    with self._app_context():
        from src.models import ProcessingJob

        active_states = ['queued', 'processing']
        active_job_query = ProcessingJob.query.filter(
            ProcessingJob.recording_id == recording_id,
            ProcessingJob.status.in_(active_states)
        )
        return active_job_query.first()
|
||||
|
||||
def cleanup_old_jobs(self, max_age_hours: int = 24):
    """Remove completed/failed jobs older than max_age_hours."""
    with self._app_context():
        from src.database import db
        from src.models import ProcessingJob
        from datetime import timedelta

        oldest_allowed = datetime.utcnow() - timedelta(hours=max_age_hours)

        # Bulk DELETE; synchronize_session=False skips per-object
        # in-session bookkeeping, which is fine for a maintenance sweep.
        removed = ProcessingJob.query.filter(
            ProcessingJob.status.in_(['completed', 'failed']),
            ProcessingJob.completed_at < oldest_allowed
        ).delete(synchronize_session=False)

        if removed:
            db.session.commit()
            logger.info(f"Cleaned up {removed} old jobs")
|
||||
|
||||
|
||||
# Global job queue instance
# Module-level singleton shared by the web app and background workers;
# workers are started lazily on the first enqueue() call.
job_queue = FairJobQueue()
|
||||
498
src/services/llm.py
Normal file
498
src/services/llm.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
LLM API integration services (OpenAI/OpenRouter).
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
|
||||
# Use standard logging instead of current_app.logger for context independence
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenBudgetExceeded(Exception):
    """Raised when user exceeds their token budget."""

    def __init__(self, message, usage_percentage=100):
        super().__init__(message)
        # Keep the raw message and the budget usage percentage so callers
        # can build a friendly response for the user.
        self.message = message
        self.usage_percentage = usage_percentage
|
||||
|
||||
from src.utils import safe_json_loads, extract_json_object
|
||||
|
||||
# Configuration - use TEXT_MODEL_* variables for LLM
# NOTE: all of these are read once at import time; changing the env vars
# requires an application restart to take effect.
TEXT_MODEL_API_KEY = os.environ.get("TEXT_MODEL_API_KEY")
TEXT_MODEL_BASE_URL = os.environ.get("TEXT_MODEL_BASE_URL", "https://openrouter.ai/api/v1")
if TEXT_MODEL_BASE_URL:
    # Strip inline '#' comments that leak in from .env files.
    TEXT_MODEL_BASE_URL = TEXT_MODEL_BASE_URL.split('#')[0].strip()
TEXT_MODEL_NAME = os.environ.get("TEXT_MODEL_NAME", "openai/gpt-3.5-turbo")

# Chat model configuration (optional - falls back to TEXT_MODEL_* if not set)
CHAT_MODEL_API_KEY = os.environ.get("CHAT_MODEL_API_KEY")
CHAT_MODEL_BASE_URL = os.environ.get("CHAT_MODEL_BASE_URL")
if CHAT_MODEL_BASE_URL:
    # Same inline-comment stripping as TEXT_MODEL_BASE_URL.
    CHAT_MODEL_BASE_URL = CHAT_MODEL_BASE_URL.split('#')[0].strip()
CHAT_MODEL_NAME = os.environ.get("CHAT_MODEL_NAME")

# Chat-specific GPT-5 settings (optional - falls back to main GPT5_* settings)
CHAT_GPT5_REASONING_EFFORT = os.environ.get("CHAT_GPT5_REASONING_EFFORT")
CHAT_GPT5_VERBOSITY = os.environ.get("CHAT_GPT5_VERBOSITY")

# Streaming options - disable for LLM servers that don't support OpenAI's stream_options
# Any value other than "true" (case-insensitive) disables it.
ENABLE_STREAM_OPTIONS = os.environ.get("ENABLE_STREAM_OPTIONS", "true").lower() == "true"
|
||||
|
||||
|
||||
def get_chat_config():
    """
    Get chat model configuration, falling back to TEXT_MODEL if not set.

    Returns a dict with api_key, base_url, model_name, and GPT-5 settings.
    """
    default_effort = os.environ.get("GPT5_REASONING_EFFORT", "medium")
    default_verbosity = os.environ.get("GPT5_VERBOSITY", "medium")

    # A dedicated chat model is only active when both its API key and
    # model name are configured.
    if CHAT_MODEL_API_KEY and CHAT_MODEL_NAME:
        return {
            'api_key': CHAT_MODEL_API_KEY,
            'base_url': CHAT_MODEL_BASE_URL or TEXT_MODEL_BASE_URL,
            'model_name': CHAT_MODEL_NAME,
            'gpt5_reasoning_effort': CHAT_GPT5_REASONING_EFFORT or default_effort,
            'gpt5_verbosity': CHAT_GPT5_VERBOSITY or default_verbosity,
        }

    # Fall back to the primary text-model configuration.
    return {
        'api_key': TEXT_MODEL_API_KEY,
        'base_url': TEXT_MODEL_BASE_URL,
        'model_name': TEXT_MODEL_NAME,
        'gpt5_reasoning_effort': default_effort,
        'gpt5_verbosity': default_verbosity,
    }
|
||||
|
||||
|
||||
# Set up HTTP client with custom headers for OpenRouter app identification
app_headers = {
    "HTTP-Referer": "https://github.com/murtaza-nasir/speakr",
    "X-Title": "Speakr - AI Audio Transcription",
    "User-Agent": "Speakr/1.0 (https://github.com/murtaza-nasir/speakr)"
}

# Shared httpx client: TLS verification on, identification headers attached.
http_client_no_proxy = httpx.Client(
    verify=True,
    headers=app_headers
)

# Create client with placeholder key if not provided (allows app to start)
try:
    api_key = TEXT_MODEL_API_KEY or "not-needed"
    client = OpenAI(
        api_key=api_key,
        base_url=TEXT_MODEL_BASE_URL,
        http_client=http_client_no_proxy
    )
except Exception as client_init_e:
    # Fail soft so the app can still start without LLM features, but
    # leave a trace instead of silently swallowing the error (the
    # original assigned client = None with no log, leaving operators
    # with no clue why LLM calls later raise "client not initialized").
    logger.error(f"Failed to initialize LLM client: {client_init_e}")
    client = None

# Create chat client (may be same as main client if no separate config)
chat_client = None
try:
    chat_config = get_chat_config()
    if chat_config['api_key']:
        if CHAT_MODEL_API_KEY and CHAT_MODEL_API_KEY != TEXT_MODEL_API_KEY:
            # Separate chat configuration - create dedicated client
            chat_client = OpenAI(
                api_key=chat_config['api_key'],
                base_url=chat_config['base_url'],
                http_client=http_client_no_proxy
            )
            logger.info(f"Separate chat client initialized: {chat_config['base_url']} / {chat_config['model_name']}")
        else:
            # Use same client as main LLM
            chat_client = client
except Exception as chat_client_init_e:
    logger.warning(f"Failed to initialize chat client, falling back to main client: {chat_client_init_e}")
    chat_client = client
|
||||
|
||||
|
||||
def is_gpt5_model(model_name):
    """
    Check if the model is a GPT-5 series model that requires special API parameters.

    Args:
        model_name: The model name string (may be None or empty)

    Returns:
        Boolean indicating if this is a GPT-5 model
    """
    if not model_name:
        return False
    # All GPT-5 variants (gpt-5, gpt-5-mini, gpt-5-nano, gpt-5-chat-latest)
    # share the 'gpt-5' prefix, so a single prefix test covers the family.
    # The original also checked an explicit membership list, but every
    # entry in it already matched the startswith test (dead code).
    return model_name.lower().startswith('gpt-5')
|
||||
|
||||
|
||||
|
||||
def is_using_openai_api():
    """
    Check if we're using the official OpenAI API (not OpenRouter or other providers).

    Returns:
        Boolean indicating if this is the OpenAI API
    """
    # Wrap in bool() so the function honors its documented contract even
    # when TEXT_MODEL_BASE_URL is None or an empty string; the bare
    # `and` expression would leak None/'' to callers instead of False.
    return bool(TEXT_MODEL_BASE_URL and 'api.openai.com' in TEXT_MODEL_BASE_URL)
|
||||
|
||||
|
||||
|
||||
def call_llm_completion(messages, temperature=0.7, response_format=None, stream=False, max_tokens=None,
                        user_id=None, operation_type=None):
    """
    Centralized function for LLM API calls with proper error handling and logging.

    Args:
        messages: List of message dicts with 'role' and 'content'
        temperature: Sampling temperature (0-1) - ignored for GPT-5 models
        response_format: Optional response format dict (e.g., {"type": "json_object"})
        stream: Whether to stream the response
        max_tokens: Optional maximum tokens to generate
        user_id: Optional user ID for token tracking and budget enforcement
        operation_type: Optional operation type for token tracking (e.g., 'summarization', 'chat')

    Returns:
        OpenAI completion object or generator (if streaming)

    Raises:
        ValueError: If the client or TEXT_MODEL_API_KEY is not configured.
        TokenBudgetExceeded: If the user is over their token budget.
    """
    if not client:
        raise ValueError("LLM client not initialized")

    if not TEXT_MODEL_API_KEY:
        raise ValueError("TEXT_MODEL_API_KEY not configured")

    # Check budget before making the call
    # (only when both a user and an operation type are supplied).
    if user_id and operation_type:
        try:
            from src.services.token_tracking import token_tracker
            can_proceed, usage_pct, msg = token_tracker.check_budget(user_id)
            if not can_proceed:
                raise TokenBudgetExceeded(msg, usage_pct)
            if usage_pct >= 80:
                logger.warning(f"User {user_id} at {usage_pct:.1f}% of token budget")
        except TokenBudgetExceeded:
            raise
        except Exception as e:
            # Log but don't block on budget check errors
            logger.warning(f"Budget check failed for user {user_id}: {e}")

    try:
        # Check if we're using GPT-5 with OpenAI API
        using_gpt5 = is_gpt5_model(TEXT_MODEL_NAME) and is_using_openai_api()

        completion_args = {
            "model": TEXT_MODEL_NAME,
            "messages": messages,
            "stream": stream
        }

        # Add stream_options to get usage in final chunk for streaming
        # Some LLM servers don't support this OpenAI-specific option
        if stream and ENABLE_STREAM_OPTIONS:
            completion_args["stream_options"] = {"include_usage": True}

        if using_gpt5:
            # GPT-5 models don't support temperature, top_p, or logprobs
            # They use reasoning_effort and verbosity instead
            logger.debug(f"Using GPT-5 model: {TEXT_MODEL_NAME} - applying GPT-5 specific parameters")

            # Get GPT-5 specific parameters from environment variables
            reasoning_effort = os.environ.get("GPT5_REASONING_EFFORT", "medium")  # minimal, low, medium, high
            verbosity = os.environ.get("GPT5_VERBOSITY", "medium")  # low, medium, high

            # Add GPT-5 specific parameters
            completion_args["reasoning_effort"] = reasoning_effort
            completion_args["verbosity"] = verbosity

            # Use max_completion_tokens instead of max_tokens for GPT-5
            if max_tokens:
                completion_args["max_completion_tokens"] = max_tokens
        else:
            # Non-GPT-5 models use standard parameters
            completion_args["temperature"] = temperature

            if max_tokens:
                completion_args["max_tokens"] = max_tokens

        if response_format:
            completion_args["response_format"] = response_format

        response = client.chat.completions.create(**completion_args)

        # Track usage for non-streaming calls
        # (streaming usage must be recorded by the stream consumer,
        # since the usage only arrives in the final chunk).
        if user_id and operation_type and not stream and response.usage:
            try:
                from src.services.token_tracking import token_tracker
                token_tracker.record_usage(
                    user_id=user_id,
                    operation_type=operation_type,
                    prompt_tokens=response.usage.prompt_tokens,
                    completion_tokens=response.usage.completion_tokens,
                    total_tokens=response.usage.total_tokens,
                    model_name=TEXT_MODEL_NAME,
                    cost=getattr(response.usage, 'cost', None)
                )
            except Exception as e:
                logger.warning(f"Failed to record token usage: {e}")

        # Debug log for empty responses
        if not stream and response.choices:
            content = response.choices[0].message.content
            if not content:
                logger.warning(f"LLM returned empty content. Model: {TEXT_MODEL_NAME}, finish_reason: {response.choices[0].finish_reason}")
                # Log more details if available
                if hasattr(response.choices[0].message, 'refusal'):
                    logger.warning(f"Refusal: {response.choices[0].message.refusal}")
                if hasattr(response.choices[0].message, 'tool_calls') and response.choices[0].message.tool_calls:
                    logger.warning(f"Tool calls present: {response.choices[0].message.tool_calls}")

        return response

    except TokenBudgetExceeded:
        raise
    except Exception as e:
        logger.error(f"LLM API call failed: {e}")
        raise
|
||||
|
||||
|
||||
def call_chat_completion(messages, temperature=0.7, response_format=None, stream=False, max_tokens=None,
                         user_id=None, operation_type=None):
    """
    Chat-specific LLM completion function. Uses dedicated chat model if configured,
    otherwise falls back to standard TEXT_MODEL configuration.

    Args:
        messages: List of message dicts with 'role' and 'content'
        temperature: Sampling temperature (0-1) - ignored for GPT-5 models
        response_format: Optional response format dict (e.g., {"type": "json_object"})
        stream: Whether to stream the response
        max_tokens: Optional maximum tokens to generate
        user_id: Optional user ID for token tracking and budget enforcement
        operation_type: Optional operation type for token tracking (e.g., 'chat')

    Returns:
        OpenAI completion object or generator (if streaming)

    Raises:
        ValueError: If no chat client or chat API key is configured.
        TokenBudgetExceeded: If the user is over their token budget.
    """
    # Prefer the dedicated chat client when one was created at import time.
    effective_client = chat_client if chat_client else client
    chat_config = get_chat_config()

    if not effective_client:
        raise ValueError("Chat LLM client not initialized")

    if not chat_config['api_key']:
        raise ValueError("Chat model API key not configured")

    # Check budget before making the call
    if user_id and operation_type:
        try:
            from src.services.token_tracking import token_tracker
            can_proceed, usage_pct, msg = token_tracker.check_budget(user_id)
            if not can_proceed:
                raise TokenBudgetExceeded(msg, usage_pct)
            if usage_pct >= 80:
                logger.warning(f"User {user_id} at {usage_pct:.1f}% of token budget")
        except TokenBudgetExceeded:
            raise
        except Exception as e:
            # Log but don't block on budget check errors
            logger.warning(f"Budget check failed for user {user_id}: {e}")

    try:
        model_name = chat_config['model_name']
        base_url = chat_config['base_url'] or ''

        # Check if we're using GPT-5 with OpenAI API
        using_gpt5 = is_gpt5_model(model_name) and 'api.openai.com' in base_url

        completion_args = {
            "model": model_name,
            "messages": messages,
            "stream": stream
        }

        # Add stream_options to get usage in final chunk for streaming
        # Some LLM servers don't support this OpenAI-specific option
        if stream and ENABLE_STREAM_OPTIONS:
            completion_args["stream_options"] = {"include_usage": True}

        if using_gpt5:
            logger.debug(f"Using GPT-5 chat model: {model_name}")
            # Use chat-specific GPT-5 settings from config
            completion_args["reasoning_effort"] = chat_config['gpt5_reasoning_effort']
            completion_args["verbosity"] = chat_config['gpt5_verbosity']

            # GPT-5 expects max_completion_tokens rather than max_tokens.
            if max_tokens:
                completion_args["max_completion_tokens"] = max_tokens
        else:
            completion_args["temperature"] = temperature
            if max_tokens:
                completion_args["max_tokens"] = max_tokens

        if response_format:
            completion_args["response_format"] = response_format

        response = effective_client.chat.completions.create(**completion_args)

        # Track usage for non-streaming calls
        # (streaming usage arrives only in the final chunk and must be
        # recorded by the stream consumer).
        if user_id and operation_type and not stream and response.usage:
            try:
                from src.services.token_tracking import token_tracker
                token_tracker.record_usage(
                    user_id=user_id,
                    operation_type=operation_type,
                    prompt_tokens=response.usage.prompt_tokens,
                    completion_tokens=response.usage.completion_tokens,
                    total_tokens=response.usage.total_tokens,
                    model_name=model_name,
                    cost=getattr(response.usage, 'cost', None)
                )
            except Exception as e:
                logger.warning(f"Failed to record token usage: {e}")

        # Debug log for empty responses
        if not stream and response.choices:
            content = response.choices[0].message.content
            if not content:
                logger.warning(f"Chat LLM returned empty content. Model: {model_name}, finish_reason: {response.choices[0].finish_reason}")

        return response

    except TokenBudgetExceeded:
        raise
    except Exception as e:
        logger.error(f"Chat LLM API call failed: {e}")
        raise
|
||||
|
||||
|
||||
def format_api_error_message(error_str):
    """
    Formats API error messages to be more user-friendly.
    Specifically handles token limit errors with helpful suggestions.
    """
    lowered = error_str.lower()

    # Ordered checks: the first matching category wins.
    if 'maximum context length' in lowered and 'tokens' in lowered:
        return "[Summary generation failed: The transcription is too long for AI processing. Request your admin to try using a different LLM with a larger context size, or set a limit for the transcript_length_limit in the system settings.]"

    if 'rate limit' in lowered:
        return "[Summary generation failed: API rate limit exceeded. Please try again in a few minutes.]"

    if 'insufficient funds' in lowered or 'quota exceeded' in lowered:
        return "[Summary generation failed: API quota exceeded. Please contact support.]"

    if 'timeout' in lowered:
        return "[Summary generation failed: Request timed out. Please try again.]"

    # Unknown error: surface the raw message inside the standard wrapper.
    return f"[Summary generation failed: {error_str}]"
|
||||
|
||||
|
||||
def process_streaming_with_thinking(stream, user_id=None, operation_type=None, model_name=None, app=None):
    """
    Generator that processes a streaming response and separates thinking content.
    Yields SSE-formatted data with 'delta' for regular content and 'thinking' for thinking content.

    Each yielded item is a Server-Sent-Events "data:" line (terminated by a
    blank line) whose JSON payload is one of {'delta': str}, {'thinking': str},
    or the final {'end_of_stream': True} sentinel.

    Args:
        stream: The streaming response from the LLM API
        user_id: Optional user ID for token tracking
        operation_type: Optional operation type for token tracking
        model_name: Optional model name for token tracking
        app: Optional Flask app instance for database context in generators

    NOTE(review): tag detection operates on the currently accumulated buffer.
    When NOT inside a thinking block, unmatched text is flushed as 'delta'
    immediately, so an opening tag split across two stream chunks (e.g.
    "<thi" + "nk>") is never detected; similarly, a closing tag split across
    chunks is absorbed into the thinking buffer and never matched. Confirm
    whether upstream models can split tags across chunks.
    """
    content_buffer = ""    # unprocessed visible text accumulated from deltas
    in_thinking = False    # True while between <think(ing)> and </think(ing)>
    thinking_buffer = ""   # accumulated text inside the current thinking block

    for chunk in stream:
        # Check for usage in final chunk (from stream_options={'include_usage': True})
        if hasattr(chunk, 'usage') and chunk.usage and user_id and operation_type:
            try:
                from src.services.token_tracking import token_tracker
                # Use app context if provided (needed for generators where context may be lost)
                if app:
                    with app.app_context():
                        token_tracker.record_usage(
                            user_id=user_id,
                            operation_type=operation_type,
                            prompt_tokens=chunk.usage.prompt_tokens,
                            completion_tokens=chunk.usage.completion_tokens,
                            total_tokens=chunk.usage.total_tokens,
                            model_name=model_name or TEXT_MODEL_NAME,
                            cost=getattr(chunk.usage, 'cost', None)
                        )
                else:
                    token_tracker.record_usage(
                        user_id=user_id,
                        operation_type=operation_type,
                        prompt_tokens=chunk.usage.prompt_tokens,
                        completion_tokens=chunk.usage.completion_tokens,
                        total_tokens=chunk.usage.total_tokens,
                        model_name=model_name or TEXT_MODEL_NAME,
                        cost=getattr(chunk.usage, 'cost', None)
                    )
            except Exception as e:
                # Token accounting is best-effort; never break the stream over it.
                logger.warning(f"Failed to record streaming token usage: {e}")

        # Process content delta
        if chunk.choices and chunk.choices[0].delta.content:
            content = chunk.choices[0].delta.content
            content_buffer += content

            # Process the buffer to detect and handle thinking tags.
            # The loop re-scans after each tag so multiple tags within one
            # chunk are all handled.
            while True:
                if not in_thinking:
                    # Look for opening thinking tag (<think> or <thinking>)
                    think_start = re.search(r'<think(?:ing)?>', content_buffer, re.IGNORECASE)
                    if think_start:
                        # Send any content before the thinking tag
                        before_thinking = content_buffer[:think_start.start()]
                        if before_thinking:
                            yield f"data: {json.dumps({'delta': before_thinking})}\n\n"

                        # Start capturing thinking content
                        in_thinking = True
                        content_buffer = content_buffer[think_start.end():]
                        thinking_buffer = ""
                    else:
                        # No thinking tag found, send accumulated content
                        if content_buffer:
                            yield f"data: {json.dumps({'delta': content_buffer})}\n\n"
                            content_buffer = ""
                        break
                else:
                    # We're inside a thinking tag, look for closing tag
                    think_end = re.search(r'</think(?:ing)?>', content_buffer, re.IGNORECASE)
                    if think_end:
                        # Capture thinking content up to the closing tag
                        thinking_buffer += content_buffer[:think_end.start()]

                        # Send the thinking content as a special type
                        if thinking_buffer.strip():
                            yield f"data: {json.dumps({'thinking': thinking_buffer.strip()})}\n\n"

                        # Continue processing after the closing tag
                        in_thinking = False
                        content_buffer = content_buffer[think_end.end():]
                        thinking_buffer = ""
                    else:
                        # Still inside thinking tag, accumulate content
                        thinking_buffer += content_buffer
                        content_buffer = ""
                        break

    # Handle any remaining content after the stream ends
    if in_thinking and thinking_buffer:
        # Unclosed thinking tag - send as thinking content
        yield f"data: {json.dumps({'thinking': thinking_buffer.strip()})}\n\n"
    elif content_buffer:
        # Regular content
        yield f"data: {json.dumps({'delta': content_buffer})}\n\n"

    # Signal the end of the stream
    yield f"data: {json.dumps({'end_of_stream': True})}\n\n"
|
||||
|
||||
|
||||
|
||||
219
src/services/retention.py
Normal file
219
src/services/retention.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Recording retention and auto-deletion services.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from flask import current_app
|
||||
|
||||
from src.database import db
|
||||
from src.models import Recording, RecordingTag, Tag
|
||||
|
||||
# Master switch: auto-deletion only runs when explicitly enabled via env.
ENABLE_AUTO_DELETION = os.environ.get('ENABLE_AUTO_DELETION', 'false').lower() == 'true'
# Global retention period in days; 0 (the default) means no global policy.
GLOBAL_RETENTION_DAYS = int(os.environ.get('GLOBAL_RETENTION_DAYS', '0'))
# 'audio_only' removes just the audio file (transcript kept);
# 'full_recording' (default) deletes the whole database record as well.
DELETION_MODE = os.environ.get('DELETION_MODE', 'full_recording')
|
||||
|
||||
|
||||
|
||||
def is_recording_exempt_from_deletion(recording):
    """
    Determine whether a recording must be kept (never auto-deleted).

    A recording is exempt when its own ``deletion_exempt`` flag is set, or
    when any associated tag either has ``protect_from_deletion`` enabled or
    uses ``retention_days == -1`` (infinite retention).

    Args:
        recording: Recording object to check

    Returns:
        Boolean indicating if the recording should be kept
    """
    # Manual per-recording exemption takes precedence.
    if recording.deletion_exempt:
        return True

    # A single protective tag is enough to exempt the whole recording.
    return any(
        assoc.tag.protect_from_deletion or assoc.tag.retention_days == -1
        for assoc in recording.tag_associations
    )
|
||||
|
||||
|
||||
|
||||
def get_retention_days_for_recording(recording):
    """
    Resolve the effective retention period for a recording.

    Multi-tier system: tag-level retention overrides the global policy, and
    when several tags define a positive retention period the SHORTEST one
    wins (most conservative). A tag value of -1 means infinite retention
    (protected) and is handled separately by is_recording_exempt_from_deletion.

    Args:
        recording: Recording object

    Returns:
        Integer days for retention period, or None if no retention applies
    """
    # Positive tag-level periods only; None, 0 and -1 are all ignored here.
    tag_periods = [
        assoc.tag.retention_days
        for assoc in recording.tag_associations
        if assoc.tag.retention_days and assoc.tag.retention_days > 0
    ]

    if tag_periods:
        # Shortest period wins.
        return min(tag_periods)

    # No tag-level policy: fall back to global retention when configured.
    return GLOBAL_RETENTION_DAYS if GLOBAL_RETENTION_DAYS > 0 else None
|
||||
|
||||
|
||||
|
||||
def process_auto_deletion():
    """
    Process auto-deletion of recordings based on retention policies.
    This can be called by a scheduled job or admin endpoint.

    Supports per-recording retention via tag-level retention_days overrides.
    Tags with retention_days set take precedence over global retention.

    Gated on ENABLE_AUTO_DELETION; behaviour further depends on DELETION_MODE
    ('audio_only' keeps the transcription and removes only the audio file;
    'full_recording' deletes the database record as well). After the recording
    pass, orphaned speaker profiles are cleaned up (best-effort).

    Returns:
        Dictionary with deletion statistics ('checked', 'deleted_audio_only',
        'deleted_full', 'exempted', 'skipped_no_retention', 'errors', plus
        speaker-cleanup counters), or {'error': ...} when disabled / on a
        fatal failure.
    """
    if not ENABLE_AUTO_DELETION:
        return {'error': 'Auto-deletion is not enabled'}

    # Check if any retention policy exists (global or tag-level)
    has_global_retention = GLOBAL_RETENTION_DAYS > 0
    # We'll check for tag-level retention on a per-recording basis

    if not has_global_retention:
        # Still check recordings in case they have tag-level retention
        current_app.logger.info("No global retention configured, checking for tag-level retention policies")

    stats = {
        'checked': 0,
        'deleted_audio_only': 0,
        'deleted_full': 0,
        'exempted': 0,
        'skipped_no_retention': 0,
        'errors': 0
    }

    try:
        # Get completed recordings to check
        # In audio_only mode: Skip recordings where audio was already deleted
        # In full_recording mode: Include all (to catch audio-only deletions for full cleanup)
        if DELETION_MODE == 'audio_only':
            all_recordings = Recording.query.filter(
                Recording.status == 'COMPLETED',
                Recording.audio_deleted_at.is_(None)  # Skip already-deleted audio
            ).all()
        else:  # full_recording mode
            all_recordings = Recording.query.filter(
                Recording.status == 'COMPLETED'
            ).all()

        stats['checked'] = len(all_recordings)
        # NOTE(review): naive UTC timestamp; assumes Recording.created_at is
        # also stored as naive UTC — confirm against the model definition.
        current_time = datetime.utcnow()

        for recording in all_recordings:
            try:
                # Check if exempt from deletion entirely
                if is_recording_exempt_from_deletion(recording):
                    stats['exempted'] += 1
                    continue

                # Get the effective retention period for this specific recording
                retention_days = get_retention_days_for_recording(recording)

                if not retention_days:
                    # No retention policy applies to this recording
                    stats['skipped_no_retention'] += 1
                    continue

                # Calculate the cutoff date for this specific recording
                cutoff_date = current_time - timedelta(days=retention_days)

                # Check if recording is past its retention period
                if recording.created_at >= cutoff_date:
                    # Recording is still within retention period
                    continue

                # Recording is past retention period - process deletion.
                # Each deletion commits individually so one failure only
                # rolls back that recording (see the except below).

                # Determine deletion mode
                if DELETION_MODE == 'audio_only':
                    # Delete only the audio file, keep transcription
                    if recording.audio_path and os.path.exists(recording.audio_path):
                        current_app.logger.info(f"Recording {recording.id} is past retention ({retention_days} days), deleting audio")
                        os.remove(recording.audio_path)
                        current_app.logger.info(f"Auto-deleted audio file: {recording.audio_path}")
                        recording.audio_deleted_at = datetime.utcnow()
                        db.session.commit()
                        stats['deleted_audio_only'] += 1
                    else:
                        # Audio already deleted or doesn't exist - just mark timestamp
                        if not recording.audio_deleted_at:
                            recording.audio_deleted_at = datetime.utcnow()
                            db.session.commit()
                            current_app.logger.debug(f"Recording {recording.id} audio file not found, marked as deleted")

                else:  # full_recording mode
                    # Check if this is completing a previous audio_only deletion
                    if recording.audio_deleted_at:
                        current_app.logger.info(f"Recording {recording.id} has deleted audio (mode changed), completing full deletion")
                    else:
                        current_app.logger.info(f"Recording {recording.id} is past retention ({retention_days} days), deleting fully")

                    # Delete audio file if it exists
                    if recording.audio_path and os.path.exists(recording.audio_path):
                        os.remove(recording.audio_path)

                    # Delete associated processing jobs (required due to NOT NULL constraint)
                    from src.models.processing_job import ProcessingJob
                    ProcessingJob.query.filter_by(recording_id=recording.id).delete()

                    # Delete the database record (cascades to chunks, shares, etc.)
                    db.session.delete(recording)
                    db.session.commit()
                    stats['deleted_full'] += 1
                    current_app.logger.info(f"Auto-deleted full recording ID: {recording.id}")

            except Exception as e:
                # Per-recording failure: count it, roll back, keep going.
                stats['errors'] += 1
                current_app.logger.error(f"Error auto-deleting recording {recording.id}: {e}")
                db.session.rollback()

        # After processing recording deletions, clean up orphaned speaker profiles
        try:
            from src.services.speaker_cleanup import cleanup_orphaned_speakers
            speaker_stats = cleanup_orphaned_speakers()
            stats['speakers_deleted'] = speaker_stats['speakers_deleted']
            stats['embeddings_cleaned'] = speaker_stats['embeddings_removed']
            stats['speakers_evaluated'] = speaker_stats['speakers_evaluated']
            current_app.logger.info(
                f"Speaker cleanup completed: {speaker_stats['speakers_deleted']} speakers deleted, "
                f"{speaker_stats['embeddings_removed']} embedding references removed"
            )
        except Exception as e:
            # Speaker cleanup is best-effort; a failure is reported in stats
            # but does not fail the whole run.
            current_app.logger.error(f"Error during speaker cleanup: {e}", exc_info=True)
            stats['speaker_cleanup_error'] = str(e)

        current_app.logger.info(f"Auto-deletion completed: {stats}")
        return stats

    except Exception as e:
        current_app.logger.error(f"Error during auto-deletion process: {e}", exc_info=True)
        return {'error': str(e)}
|
||||
|
||||
# --- API client setup for OpenRouter ---
|
||||
# Use environment variables from .env
|
||||
|
||||
|
||||
217
src/services/speaker.py
Normal file
217
src/services/speaker.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Speaker identification and management services.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
|
||||
from src.database import db
|
||||
from src.models import Speaker, SystemSetting
|
||||
from src.services.llm import call_llm_completion
|
||||
from src.utils import safe_json_loads
|
||||
|
||||
def format_transcription_for_llm(transcription):
    """
    Format a transcription for LLM processing.

    If *transcription* is a JSON-encoded list of diarized segments (dicts
    with 'text' and optionally 'speaker'), render one "[SPEAKER] text" line
    per segment. Any other input is returned as a string unchanged.

    Args:
        transcription: Raw transcription — plain text, a JSON string of
            diarized segments, or any other object.

    Returns:
        str: Text suitable for inclusion in an LLM prompt.
    """
    if isinstance(transcription, str):
        import json
        try:
            data = json.loads(transcription)
        except json.JSONDecodeError:
            # Not JSON: treat as plain text.
            pass
        else:
            # JSON diarized format: extract one "[SPEAKER] text" line per
            # segment; skip malformed (non-dict) entries instead of letting
            # them abort the whole formatting (the old bare `except` used to
            # swallow those errors and silently fall back to the raw string).
            if isinstance(data, list):
                return '\n'.join(
                    f"[{seg.get('speaker', 'UNKNOWN')}] {seg.get('text', '')}"
                    for seg in data
                    if isinstance(seg, dict) and 'text' in seg
                )
    return str(transcription)
|
||||
|
||||
# Import TEXT_MODEL_API_KEY from llm service
|
||||
from src.services.llm import TEXT_MODEL_API_KEY
|
||||
|
||||
|
||||
def update_speaker_usage(speaker_names):
    """Helper function to update speaker usage statistics.

    For each non-empty name, increments the per-user Speaker.use_count and
    refreshes last_used, creating the Speaker row on first use. Requires an
    authenticated Flask-Login user; silently no-ops otherwise. Commits once
    at the end and rolls back (logging the error) on any failure.

    Args:
        speaker_names: Iterable of speaker display names; whitespace is
            stripped and blank names are skipped.
    """
    if not speaker_names or not current_user.is_authenticated:
        return

    try:
        for name in speaker_names:
            name = name.strip()
            if not name:
                continue

            # Speakers are scoped per user: look up by (user_id, name).
            speaker = Speaker.query.filter_by(user_id=current_user.id, name=name).first()
            if speaker:
                speaker.use_count += 1
                speaker.last_used = datetime.utcnow()
            else:
                # Create new speaker
                speaker = Speaker(
                    name=name,
                    user_id=current_user.id,
                    use_count=1,
                    created_at=datetime.utcnow(),
                    last_used=datetime.utcnow()
                )
                db.session.add(speaker)

        # Single commit for all names; failures roll everything back.
        db.session.commit()
    except Exception as e:
        current_app.logger.error(f"Error updating speaker usage: {e}")
        db.session.rollback()
|
||||
|
||||
|
||||
|
||||
def identify_speakers_from_text(transcription):
    """
    Uses an LLM to identify speakers from a transcription.

    Extracts SPEAKER_NN labels from the formatted transcript, asks the text
    model to map each label to a likely real name, and returns the mapping.
    Names reported as "Unknown" (or non-string values) are replaced by "".

    Args:
        transcription: Raw transcription (plain text or diarized JSON).

    Returns:
        dict: {"SPEAKER_00": "Name", ...}; empty dict when the transcript
        contains no speaker labels.

    Raises:
        ValueError: If TEXT_MODEL_API_KEY is not configured.
        Exception: Propagates any LLM call failure.
    """
    if not TEXT_MODEL_API_KEY:
        raise ValueError("TEXT_MODEL_API_KEY not configured.")

    # The transcription passed here could be JSON, so we format it.
    formatted_transcription = format_transcription_for_llm(transcription)

    # Extract existing speaker labels (e.g., SPEAKER_00, SPEAKER_01) in order of appearance
    all_labels = re.findall(r'\[(SPEAKER_\d+)\]', formatted_transcription)
    seen = set()
    speaker_labels = [x for x in all_labels if not (x in seen or seen.add(x))]

    if not speaker_labels:
        return {}

    # Get configurable transcript length limit (-1 means unlimited)
    transcript_limit = SystemSetting.get_setting('transcript_length_limit', 30000)
    if transcript_limit == -1:
        # No limit
        transcript_text = formatted_transcription
    else:
        transcript_text = formatted_transcription[:transcript_limit]

    prompt = f"""Analyze the following transcription and identify the names of the speakers. The speakers are labeled as {', '.join(speaker_labels)}. Based on the context of the conversation, determine the most likely name for each speaker label.

Transcription:
---
{transcript_text}
---

Respond with a single JSON object where keys are the speaker labels (e.g., "SPEAKER_00") and values are the identified full names. If a name cannot be determined, use the value "Unknown".

Example:
{{
    "SPEAKER_00": "John Doe",
    "SPEAKER_01": "Jane Smith",
    "SPEAKER_02": "Unknown"
}}

JSON Response:
"""

    try:
        completion = call_llm_completion(
            messages=[
                {"role": "system", "content": "You are an expert in analyzing conversation transcripts to identify speakers. Your response must be a single, valid JSON object."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.2
        )
        response_content = completion.choices[0].message.content
        speaker_map = safe_json_loads(response_content, {})

        # Post-process the map to replace "Unknown" with an empty string.
        # Guard against non-string values the LLM may return (e.g. null),
        # which would previously crash on .strip() — mirrors the guard in
        # identify_unidentified_speakers_from_text.
        for speaker_label, identified_name in speaker_map.items():
            if not isinstance(identified_name, str) or identified_name.strip().lower() == "unknown":
                speaker_map[speaker_label] = ""

        return speaker_map
    except Exception as e:
        current_app.logger.error(f"Error calling LLM for speaker identification: {e}")
        raise
|
||||
|
||||
|
||||
def identify_unidentified_speakers_from_text(transcription, unidentified_speakers):
    """
    Uses an LLM to identify only the unidentified speakers from a transcription.

    Args:
        transcription: Raw transcription (plain text or diarized JSON).
        unidentified_speakers: List of speaker labels (e.g. "SPEAKER_01")
            that still lack a real name.

    Returns:
        dict: Mapping label -> identified name ("" when undetermined);
        empty dict when there is nothing to identify.

    Raises:
        ValueError: If TEXT_MODEL_API_KEY is not configured.
        Exception: Propagates any LLM call failure.
    """
    if not TEXT_MODEL_API_KEY:
        raise ValueError("TEXT_MODEL_API_KEY not configured.")

    # The transcription passed here could be JSON, so we format it.
    formatted_transcription = format_transcription_for_llm(transcription)

    if not unidentified_speakers:
        return {}

    # Get configurable transcript length limit (-1 means unlimited)
    transcript_limit = SystemSetting.get_setting('transcript_length_limit', 30000)
    if transcript_limit == -1:
        # No limit
        transcript_text = formatted_transcription
    else:
        transcript_text = formatted_transcription[:transcript_limit]

    prompt = f"""Analyze the following conversation transcript and identify the names of the UNIDENTIFIED speakers based on the context and content of their dialogue.

The speakers that need to be identified are: {', '.join(unidentified_speakers)}

Look for clues in the conversation such as:
- Names mentioned by other speakers when addressing someone
- Self-introductions or references to their own name
- Context clues about roles, relationships, or positions
- Any direct mentions of names in the dialogue

Here is the complete conversation transcript:

{transcript_text}

Based on the conversation above, identify the most likely real names for the unidentified speakers. Pay close attention to how speakers address each other and any names that are mentioned in the dialogue.

Respond with a single JSON object where keys are the speaker labels (e.g., "SPEAKER_01") and values are the identified full names. If a name cannot be determined from the conversation context, use an empty string "".

Example format:
{{
    "SPEAKER_01": "Jane Smith",
    "SPEAKER_03": "Bob Johnson",
    "SPEAKER_05": ""
}}

JSON Response:
"""

    try:
        current_app.logger.info(f"[Auto-Identify] Calling LLM to identify speakers: {unidentified_speakers}")
        current_app.logger.info(f"[Auto-Identify] Transcript excerpt (first 500 chars): {transcript_text[:500]}")

        completion = call_llm_completion(
            messages=[
                {"role": "system", "content": "You are an expert in analyzing conversation transcripts to identify speakers based on contextual clues in the dialogue. Analyze the conversation carefully to find names mentioned when speakers address each other or introduce themselves. Your response must be a single, valid JSON object containing only the requested speaker identifications."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.2
        )
        response_content = completion.choices[0].message.content
        current_app.logger.info(f"[Auto-Identify] LLM Raw Response: {response_content}")

        speaker_map = safe_json_loads(response_content, {})
        current_app.logger.info(f"[Auto-Identify] Parsed speaker_map: {speaker_map}")

        # Post-process the map to replace "Unknown" (and similar LLM
        # placeholders) with an empty string; the truthiness check also
        # guards against non-string/None values before calling .strip().
        for speaker_label, identified_name in speaker_map.items():
            if identified_name and identified_name.strip().lower() in ["unknown", "n/a", "not available", "unclear"]:
                speaker_map[speaker_label] = ""

        current_app.logger.info(f"[Auto-Identify] Final speaker_map after post-processing: {speaker_map}")
        return speaker_map
    except Exception as e:
        current_app.logger.error(f"Error calling LLM for speaker identification: {e}", exc_info=True)
        raise
|
||||
|
||||
295
src/services/speaker_cleanup.py
Normal file
295
src/services/speaker_cleanup.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Speaker cleanup service for managing orphaned speaker voice profiles.
|
||||
|
||||
This module provides automatic cleanup of speaker records when their associated
|
||||
recordings are deleted through auto-deletion or manual deletion processes.
|
||||
|
||||
By default, speaker profiles (including voice embeddings) are preserved even
|
||||
when all their recordings are deleted, since embeddings are aggregated and
|
||||
represent hours of manual identification work. Set DELETE_ORPHANED_SPEAKERS=true
|
||||
to enable automatic cleanup of speakers with no remaining recordings.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
from sqlalchemy import exists
|
||||
from src.database import db
|
||||
from src.models import Speaker, SpeakerSnippet, Recording
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cleanup_orphaned_speakers(dry_run=False):
    """
    Clean up speaker records that no longer have any associated recordings.

    Only runs if DELETE_ORPHANED_SPEAKERS=true is set. By default, speaker
    profiles are preserved because voice embeddings are aggregated values
    that can't be reconstructed from recordings alone.

    A speaker is considered orphaned when:
    - It has no SpeakerSnippet records
    - Its embeddings_history contains no valid recording references

    Args:
        dry_run (bool): If True, only report what would be deleted without actually deleting

    Returns:
        dict: Statistics about cleanup operation
            {
                'speakers_deleted': int,
                'embeddings_removed': int,
                'speakers_evaluated': int,
                'orphaned_speakers': list of dict (if dry_run=True)
            }
    """
    # Env flag is re-read on every call (unlike the retention constants),
    # so it can be toggled without restarting the app.
    delete_orphans = os.environ.get('DELETE_ORPHANED_SPEAKERS', 'false').lower() in ('true', '1', 'yes')

    if not delete_orphans:
        logger.debug("Speaker cleanup skipped (DELETE_ORPHANED_SPEAKERS is not enabled)")
        return {
            'speakers_deleted': 0,
            'embeddings_removed': 0,
            'speakers_evaluated': 0,
            'orphaned_speakers': []
        }

    logger.info("Starting speaker cleanup process (dry_run=%s)", dry_run)

    stats = {
        'speakers_deleted': 0,
        'embeddings_removed': 0,
        'speakers_evaluated': 0,
        'orphaned_speakers': []
    }

    try:
        # Clean embeddings_history references first, so the orphan check
        # below sees up-to-date recording references.
        embeddings_cleaned = clean_embeddings_history_references(dry_run=dry_run)
        stats['embeddings_removed'] = embeddings_cleaned

        # Find and process orphaned speakers
        orphaned_speaker_ids = get_orphaned_speakers()
        stats['speakers_evaluated'] = Speaker.query.count()

        if not orphaned_speaker_ids:
            logger.info("No orphaned speakers found")
            return stats

        logger.info("Found %d orphaned speaker(s)", len(orphaned_speaker_ids))

        if dry_run:
            # Report what would be deleted
            for speaker_id in orphaned_speaker_ids:
                speaker = Speaker.query.get(speaker_id)
                if speaker:
                    stats['orphaned_speakers'].append({
                        'id': speaker.id,
                        'name': speaker.name,
                        'user_id': speaker.user_id,
                        'embedding_count': speaker.embedding_count
                    })
            logger.info("Dry run: Would delete %d speakers", len(orphaned_speaker_ids))
        else:
            # Actually delete orphaned speakers
            for speaker_id in orphaned_speaker_ids:
                speaker = Speaker.query.get(speaker_id)
                if speaker:
                    logger.debug(
                        "Deleting orphaned speaker: id=%d, name='%s', user_id=%d, embedding_count=%d",
                        speaker.id, speaker.name, speaker.user_id, speaker.embedding_count or 0
                    )
                    db.session.delete(speaker)
                    stats['speakers_deleted'] += 1

            # Commit all deletions in one transaction
            db.session.commit()
            logger.info("Speaker cleanup completed: %d speakers deleted", stats['speakers_deleted'])

            # Warning if large number deleted
            if stats['speakers_deleted'] >= 50:
                logger.warning(
                    "Large number of speakers deleted (%d). Review cleanup logic if unexpected.",
                    stats['speakers_deleted']
                )

        return stats

    except Exception as e:
        db.session.rollback()
        logger.error("Error during speaker cleanup: %s", str(e), exc_info=True)
        raise
|
||||
|
||||
|
||||
def clean_embeddings_history_references(dry_run=False):
    """
    Clean embeddings_history JSON fields to remove references to deleted recordings.

    Scans all speakers' embeddings_history and removes entries where the
    recording_id no longer exists in the database.

    Note: entries that are not dicts or lack a 'recording_id' key are also
    dropped when a speaker's history is rewritten (they are skipped without
    being added to the cleaned list), but they do NOT count toward the
    returned total.

    Args:
        dry_run (bool): If True, only count what would be cleaned

    Returns:
        int: Number of embedding references removed
    """
    logger.debug("Cleaning embeddings_history references (dry_run=%s)", dry_run)

    references_removed = 0

    try:
        # Get all speakers with embeddings_history
        speakers = Speaker.query.filter(Speaker.embeddings_history.isnot(None)).all()

        for speaker in speakers:
            try:
                # Parse embeddings_history JSON (may already be a list when
                # the column is a JSON type, or a string needing json.loads)
                if not speaker.embeddings_history:
                    continue

                history = speaker.embeddings_history if isinstance(speaker.embeddings_history, list) else json.loads(speaker.embeddings_history)

                if not history or not isinstance(history, list):
                    continue

                # Filter out entries with deleted recording_ids
                cleaned_history = []
                for entry in history:
                    # Malformed entries are skipped (and thereby dropped on rewrite)
                    if not isinstance(entry, dict) or 'recording_id' not in entry:
                        continue

                    recording_id = entry['recording_id']

                    # Check if recording still exists (EXISTS subquery, one per entry)
                    recording_exists = db.session.query(
                        exists().where(Recording.id == recording_id)
                    ).scalar()

                    if recording_exists:
                        cleaned_history.append(entry)
                    else:
                        references_removed += 1
                        logger.debug(
                            "Removing deleted recording reference: speaker_id=%d, recording_id=%d",
                            speaker.id, recording_id
                        )

                # Update speaker if history changed
                if len(cleaned_history) < len(history):
                    if not dry_run:
                        speaker.embeddings_history = cleaned_history
                        logger.debug(
                            "Updated speaker %d embeddings_history: %d -> %d entries",
                            speaker.id, len(history), len(cleaned_history)
                        )

            except (json.JSONDecodeError, TypeError, KeyError) as e:
                # A corrupt history on one speaker must not abort the sweep.
                logger.warning(
                    "Error processing embeddings_history for speaker %d: %s",
                    speaker.id, str(e)
                )
                continue

        if not dry_run and references_removed > 0:
            db.session.commit()
            logger.debug("Cleaned %d embedding references", references_removed)

        return references_removed

    except Exception as e:
        db.session.rollback()
        logger.error("Error cleaning embeddings_history: %s", str(e), exc_info=True)
        raise
|
||||
|
||||
|
||||
def get_orphaned_speakers(user_id=None):
    """
    Get list of speaker IDs that are orphaned (no associated recordings).

    A speaker is orphaned when:
    - It has no SpeakerSnippet records
    - After cleaning embeddings_history, it has no valid recording references

    Args:
        user_id (int, optional): Filter to specific user's speakers

    Returns:
        list: List of speaker IDs that are orphaned
    """
    logger.debug("Finding orphaned speakers (user_id=%s)", user_id)

    # Query for speakers with no snippets (anti-join via NOT EXISTS)
    query = Speaker.query.filter(
        ~exists().where(SpeakerSnippet.speaker_id == Speaker.id)
    )

    if user_id is not None:
        query = query.filter(Speaker.user_id == user_id)

    speakers_without_snippets = query.all()

    orphaned_ids = []

    for speaker in speakers_without_snippets:
        # Check if embeddings_history has any valid recording references
        has_valid_recordings = False

        if speaker.embeddings_history:
            try:
                # History may already be a list (JSON column) or a raw string
                history = speaker.embeddings_history if isinstance(speaker.embeddings_history, list) else json.loads(speaker.embeddings_history)

                if history and isinstance(history, list):
                    for entry in history:
                        if isinstance(entry, dict) and 'recording_id' in entry:
                            recording_id = entry['recording_id']

                            # Check if this recording exists
                            recording_exists = db.session.query(
                                exists().where(Recording.id == recording_id)
                            ).scalar()

                            if recording_exists:
                                # One live reference is enough to keep the speaker
                                has_valid_recordings = True
                                break
            except (json.JSONDecodeError, TypeError, KeyError):
                # Unparseable history is treated as "no valid references"
                pass

        # If no snippets AND no valid recording references, it's orphaned
        if not has_valid_recordings:
            orphaned_ids.append(speaker.id)
            logger.debug(
                "Speaker %d ('%s') is orphaned: no snippets, no valid recordings",
                speaker.id, speaker.name
            )

    return orphaned_ids
|
||||
|
||||
|
||||
def get_speaker_cleanup_statistics():
    """
    Gather speaker-table statistics for monitoring cleanup behaviour.

    Returns:
        dict: Statistics about speakers
            {
                'total_speakers': int,          # all Speaker rows
                'speakers_with_snippets': int,  # speakers with >= 1 snippet
                'speakers_with_embeddings': int,# speakers with an average embedding
                'potential_orphans': int        # candidates per get_orphaned_speakers()
            }
    """
    total = Speaker.query.count()

    # Distinct speakers that have at least one associated snippet.
    with_snippets = (
        db.session.query(Speaker.id)
        .join(SpeakerSnippet, Speaker.id == SpeakerSnippet.speaker_id)
        .distinct()
        .count()
    )

    with_embeddings = Speaker.query.filter(
        Speaker.average_embedding.isnot(None)
    ).count()

    return {
        'total_speakers': total,
        'speakers_with_snippets': with_snippets,
        'speakers_with_embeddings': with_embeddings,
        'potential_orphans': len(get_orphaned_speakers()),
    }
|
||||
453
src/services/speaker_embedding_matcher.py
Normal file
453
src/services/speaker_embedding_matcher.py
Normal file
@@ -0,0 +1,453 @@
|
||||
"""
|
||||
Speaker Embedding Matcher Service.
|
||||
|
||||
This service handles voice embedding comparison and matching for speaker identification.
|
||||
It provides functions to:
|
||||
- Serialize/deserialize speaker embeddings for database storage
|
||||
- Calculate cosine similarity between voice embeddings
|
||||
- Find matching speakers based on voice similarity
|
||||
- Update speaker profiles with new embeddings
|
||||
- Calculate confidence scores for speaker profiles
|
||||
|
||||
Uses 256-dimensional embeddings from WhisperX diarization.
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
try:
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
except ImportError:
|
||||
cosine_similarity = None
|
||||
from src.database import db
|
||||
from src.models import Speaker
|
||||
|
||||
|
||||
def serialize_embedding(embedding_array):
    """
    Serialize a voice embedding into raw bytes for database storage.

    Args:
        embedding_array: numpy array or list of floats (256 dimensions)

    Returns:
        bytes: float32 binary blob (1,024 bytes for a 256-dim vector)
    """
    as_float32 = np.array(embedding_array, dtype=np.float32)
    return as_float32.tobytes()
|
||||
|
||||
|
||||
def deserialize_embedding(binary_data):
    """
    Reconstruct a voice embedding from its binary database form.

    Args:
        binary_data: bytes produced by serialize_embedding (1,024 bytes
            for a full 256-dim embedding)

    Returns:
        numpy.ndarray: float32 vector (read-only view over the buffer)
    """
    vector = np.frombuffer(binary_data, dtype=np.float32)
    return vector
|
||||
|
||||
|
||||
def calculate_similarity(embedding1, embedding2):
    """
    Compute cosine similarity between two voice embeddings.

    The module-level sklearn import is guarded and leaves
    cosine_similarity set to None when scikit-learn is not installed;
    previously this function then crashed with a TypeError. It now
    falls back to an equivalent pure-numpy computation in that case.

    Args:
        embedding1: numpy array, list, or binary-decoded sequence of floats
        embedding2: numpy array, list, or binary-decoded sequence of floats

    Returns:
        float: Similarity score (cosine similarity ranges -1 to 1;
        voice embeddings typically land in 0.6-0.99). Returns 0.0 when
        either vector has zero norm.
    """
    # Convert to numpy arrays if needed
    e1 = np.array(embedding1, dtype=np.float32).reshape(1, -1)
    e2 = np.array(embedding2, dtype=np.float32).reshape(1, -1)

    # Pick up the module-level guarded sklearn import; NameError guard
    # keeps this function usable even in isolation.
    try:
        sklearn_cosine = cosine_similarity
    except NameError:
        sklearn_cosine = None

    if sklearn_cosine is not None:
        return float(sklearn_cosine(e1, e2)[0][0])

    # Pure-numpy fallback: dot(a, b) / (|a| * |b|)
    v1, v2 = e1.ravel(), e2.ravel()
    denom = float(np.linalg.norm(v1)) * float(np.linalg.norm(v2))
    if denom == 0.0:
        # Degenerate zero vector -- define similarity as 0.0
        # (sklearn's implementation also yields 0 here; verify if relied on).
        return 0.0
    return float(np.dot(v1, v2) / denom)
|
||||
|
||||
|
||||
def find_matching_speakers(target_embedding, user_id, threshold=0.70):
    """
    Rank a user's saved speakers by voice similarity to a target embedding.

    Args:
        target_embedding: The voice embedding to match against (256-dim array/list)
        user_id: User ID to search within
        threshold: Minimum similarity score (0-1, default 0.70 = 70%)

    Returns:
        list: Matches sorted by descending similarity, e.g.
        [{'speaker_id': 5, 'name': 'John', 'similarity': 85.3,
          'confidence': 0.92, 'embedding_count': 3}, ...]
    """
    # Only consider this user's speakers that actually have a stored embedding.
    candidates = Speaker.query.filter_by(user_id=user_id).filter(
        Speaker.average_embedding.isnot(None)
    ).all()

    if not candidates:
        return []

    results = []
    for candidate in candidates:
        try:
            stored = deserialize_embedding(candidate.average_embedding)
            score = calculate_similarity(target_embedding, stored)

            if score >= threshold:
                results.append({
                    'speaker_id': candidate.id,
                    'name': candidate.name,
                    'similarity': round(score * 100, 1),  # percentage for the UI
                    'confidence': candidate.confidence_score or 0.5,
                    'embedding_count': candidate.embedding_count or 0,
                })
        except Exception:
            # Corrupted embedding blob -- skip this profile entirely.
            continue

    results.sort(key=lambda m: m['similarity'], reverse=True)
    return results
|
||||
|
||||
|
||||
def update_speaker_embedding(speaker, new_embedding, recording_id):
    """
    Update a speaker's average embedding and history with a new sample.

    Uses weighted moving average to update the profile:
    - New embeddings get 30% weight
    - Existing average gets 70% weight

    Side effects: mutates the Speaker row (average_embedding,
    embedding_count, embeddings_history, confidence_score) and commits
    the session immediately.

    Args:
        speaker: Speaker model instance
        new_embedding: New voice embedding (256-dim array/list)
        recording_id: ID of the recording this embedding came from

    Returns:
        float: Similarity between new embedding and previous average (None if first)
    """
    new_emb_array = np.array(new_embedding, dtype=np.float32)
    similarity_to_avg = None

    if speaker.average_embedding is None:
        # First embedding for this speaker
        speaker.average_embedding = serialize_embedding(new_emb_array)
        speaker.embedding_count = 1
        speaker.embeddings_history = [{
            'recording_id': recording_id,
            'timestamp': datetime.utcnow().isoformat(),
            'similarity': 100.0  # Perfect match to itself
        }]
    else:
        # Update existing average
        current_avg = deserialize_embedding(speaker.average_embedding)
        similarity_to_avg = calculate_similarity(new_emb_array, current_avg)

        # Weighted average: 30% new, 70% existing
        # This prevents sudden shifts while still adapting to voice changes
        weight = 0.3
        updated_avg = (1 - weight) * current_avg + weight * new_emb_array

        speaker.average_embedding = serialize_embedding(updated_avg)
        speaker.embedding_count += 1

        # Add to history (keep last 10 entries)
        # NOTE(review): history is rebuilt and reassigned rather than mutated
        # in place -- presumably so the ORM's change tracking notices the
        # update on a JSON column; confirm against the column type.
        history = speaker.embeddings_history or []
        history.append({
            'recording_id': recording_id,
            'timestamp': datetime.utcnow().isoformat(),
            'similarity': round(similarity_to_avg * 100, 1)
        })
        speaker.embeddings_history = history[-10:]  # Keep most recent 10

    # Recalculate confidence score
    speaker.confidence_score = calculate_confidence(speaker)

    # Commit changes
    db.session.commit()

    return similarity_to_avg
|
||||
|
||||
|
||||
def calculate_confidence(speaker):
    """
    Score how reliable a speaker's voice profile is (0-1).

    The score combines two signals:
    - consistency: how similar recent samples were to the running average
    - volume: how many samples have been collected (few samples penalized)

    Args:
        speaker: Speaker model instance with embeddings_history

    Returns:
        float: Confidence score clamped to [0, 1]
    """
    count = speaker.embedding_count
    if count is None or count < 1:
        return 0.0

    if count == 1:
        # A single sample gives nothing to compare against.
        return 0.5

    history = speaker.embeddings_history or []
    if len(history) < 2:
        return 0.5

    # Mean similarity over up to the last five samples (stored as percentages).
    window = history[-5:]
    mean_similarity = sum(entry.get('similarity', 0) / 100.0 for entry in window) / len(window)

    # Scale down until at least 5 samples have been collected.
    sample_factor = min(1.0, count / 5.0)

    return min(1.0, max(0.0, mean_similarity * sample_factor))
|
||||
|
||||
|
||||
def get_speaker_voice_profile_summary(speaker):
    """
    Build a human-readable summary of a speaker's voice profile.

    Args:
        speaker: Speaker model instance

    Returns:
        dict: Profile statistics and status; {'has_profile': False, ...}
        when the speaker has no stored embedding yet.
    """
    if not speaker.average_embedding:
        return {
            'has_profile': False,
            'message': 'No voice profile yet',
        }

    history = speaker.embeddings_history
    latest_timestamp = history[-1]['timestamp'] if history else None

    return {
        'has_profile': True,
        'embedding_count': speaker.embedding_count or 0,
        'confidence_score': speaker.confidence_score or 0.0,
        'confidence_level': _get_confidence_level(speaker.confidence_score),
        'last_updated': latest_timestamp,
        'recordings': len(history or []),
    }
|
||||
|
||||
|
||||
def _get_confidence_level(score):
|
||||
"""
|
||||
Convert numeric confidence score to human-readable level.
|
||||
|
||||
Args:
|
||||
score: float (0-1)
|
||||
|
||||
Returns:
|
||||
str: 'low', 'medium', or 'high'
|
||||
"""
|
||||
if score is None or score < 0.6:
|
||||
return 'low'
|
||||
elif score < 0.8:
|
||||
return 'medium'
|
||||
else:
|
||||
return 'high'
|
||||
|
||||
|
||||
# Threshold mapping for auto-labelling.
# Each level is the minimum best-match similarity (0-1 scale) required
# before a speaker name is applied automatically.
AUTO_LABEL_THRESHOLDS = {
    'low': 0.3,  # Aggressive, may have more false positives
    'medium': 0.6,  # Default, balanced approach
    'high': 0.8  # Only auto-label well-established speakers
}

# Base similarity threshold for finding matches (70%)
BASE_SIMILARITY_THRESHOLD = 0.70

# Ambiguity threshold: if top 2 matches are within 5% similarity, skip
AMBIGUITY_MARGIN = 0.05
|
||||
|
||||
|
||||
def apply_auto_speaker_labels(recording, user):
    """
    Automatically label speakers in a recording via voice-profile matching.

    Each embedding in recording.speaker_embeddings is matched against the
    user's saved profiles. A name is assigned only when the best match
    clears the user's confidence threshold AND beats the runner-up by
    more than AMBIGUITY_MARGIN.

    Args:
        recording: Recording model instance with speaker_embeddings
        user: User model instance with auto_speaker_labelling settings

    Returns:
        dict: {SPEAKER_XX: speaker_name} for confident matches, or an
        empty dict if auto-labelling is disabled or nothing matched.
    """
    if not user.auto_speaker_labelling:
        return {}

    if not recording.speaker_embeddings:
        return {}

    # Resolve the user's threshold setting to a numeric confidence floor.
    setting = user.auto_speaker_labelling_threshold or 'medium'
    min_confidence = AUTO_LABEL_THRESHOLDS.get(setting, AUTO_LABEL_THRESHOLDS['medium'])

    resolved = {}
    for label, embedding in recording.speaker_embeddings.items():
        # Each entry should be a list of floats (256 dimensions).
        if not embedding or not isinstance(embedding, list):
            continue

        candidates = find_matching_speakers(
            target_embedding=embedding,
            user_id=user.id,
            threshold=BASE_SIMILARITY_THRESHOLD
        )
        if not candidates:
            continue

        top = candidates[0]
        top_score = top['similarity'] / 100.0  # stored as a percentage

        if top_score < min_confidence:
            continue

        # Ambiguity guard: skip when the runner-up is within the margin.
        if len(candidates) >= 2:
            runner_up_score = candidates[1]['similarity'] / 100.0
            if (top_score - runner_up_score) <= AMBIGUITY_MARGIN:
                continue

        # Clear winner -- record the mapping.
        resolved[label] = top['name']

    return resolved
|
||||
|
||||
|
||||
def apply_speaker_names_to_transcription(recording, speaker_map):
    """
    Replace generic speaker labels in a recording's transcription.

    Rewrites the transcription JSON so that SPEAKER_XX labels become
    real names from speaker_map, refreshes the recording's participants
    list, and commits the session.

    Args:
        recording: Recording model instance with transcription
        speaker_map: Dict mapping {SPEAKER_XX: speaker_name}

    Returns:
        bool: True if changes were made, False otherwise
    """
    import logging
    logger = logging.getLogger(__name__)

    if not speaker_map or not recording.transcription:
        logger.warning(f"Auto-label: No speaker_map or transcription (map={bool(speaker_map)}, trans={bool(recording.transcription)})")
        return False

    # Expected shape: [{speaker, sentence, start_time, end_time}, ...]
    try:
        segments = json.loads(recording.transcription)
    except (json.JSONDecodeError, TypeError) as e:
        logger.warning(f"Auto-label: Failed to parse transcription as JSON: {e}")
        return False

    if not isinstance(segments, list) or not segments:
        logger.warning(f"Auto-label: Transcription not in expected array format")
        return False

    # Swap labels in place, remembering which names were applied.
    renamed_speakers = set()
    for segment in segments:
        if 'speaker' in segment and segment['speaker'] in speaker_map:
            new_name = speaker_map[segment['speaker']]
            segment['speaker'] = new_name
            renamed_speakers.add(new_name)

    if not renamed_speakers:
        logger.warning(f"Auto-label: No speakers matched in segments")
        return False

    logger.info(f"Auto-label: Applied names to {len(renamed_speakers)} speakers: {renamed_speakers}")

    # Rebuild the participants field from every speaker now present.
    all_speakers = {seg.get('speaker') for seg in segments if 'speaker' in seg}
    if all_speakers:
        recording.participants = ', '.join(sorted(all_speakers))

    # Persist the rewritten transcription.
    recording.transcription = json.dumps(segments)
    db.session.commit()

    return True
|
||||
|
||||
|
||||
def update_speaker_profiles_from_recording(recording, speaker_map, user):
    """
    Feed a recording's embeddings back into matched speaker profiles.

    For each applied {SPEAKER_XX: name} mapping, updates that speaker's
    average embedding, usage counter, and last-used timestamp. Commits
    once at the end if anything changed.

    Args:
        recording: Recording model instance with speaker_embeddings
        speaker_map: Dict mapping {SPEAKER_XX: speaker_name} that was applied
        user: User model instance

    Returns:
        int: Number of speaker profiles updated
    """
    if not speaker_map or not recording.speaker_embeddings:
        return 0

    embeddings = recording.speaker_embeddings
    updated = 0

    for label, name in speaker_map.items():
        embedding = embeddings.get(label)
        # Only well-formed embeddings (non-empty lists of floats) count.
        if not embedding or not isinstance(embedding, list):
            continue

        # Locate this user's profile by the applied name.
        profile = Speaker.query.filter_by(
            user_id=user.id,
            name=name
        ).first()
        if profile is None:
            continue

        try:
            # Fold the new sample into the profile's running average.
            update_speaker_embedding(profile, embedding, recording.id)

            # Track usage.
            profile.use_count = (profile.use_count or 0) + 1
            profile.last_used = datetime.utcnow()

            updated += 1
        except Exception:
            # A bad embedding shouldn't block the remaining speakers.
            continue

    if updated > 0:
        db.session.commit()

    return updated
|
||||
228
src/services/speaker_identification.py
Normal file
228
src/services/speaker_identification.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Shared speaker identification service.
|
||||
|
||||
Provides LLM-based speaker identification from transcript context,
|
||||
used by both the web UI (recordings.py) and REST API (api_v1.py).
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from flask import current_app
|
||||
|
||||
|
||||
def identify_speakers_from_transcript(transcription_data, user_id):
    """
    Identify speakers in a transcription using an LLM.

    Normalizes speaker labels to SPEAKER_XX, sends the formatted
    transcript to the LLM (optionally with a strict JSON schema),
    sanitizes the response, and maps names back to the original labels.

    Args:
        transcription_data: List of transcript segments (already parsed JSON).
        user_id: Current user's ID (for token tracking).

    Returns:
        dict mapping original speaker labels to identified names.
        Values are empty string "" for unidentified speakers.

    Raises:
        ValueError: If LLM API key is not configured.
        Exception: On LLM call failure.
    """
    # Imported inside the function -- NOTE(review): presumably to avoid a
    # circular import at module load time; confirm before moving to top level.
    from src.services.llm import call_llm_completion
    from src.utils import safe_json_loads
    from src.models import SystemSetting

    # Extract unique speakers in order of appearance
    seen_speakers = set()
    unique_speakers = []
    for segment in transcription_data:
        speaker = segment.get('speaker')
        if speaker and speaker not in seen_speakers:
            seen_speakers.add(speaker)
            unique_speakers.append(speaker)

    if not unique_speakers:
        return {}

    # Normalize all labels to SPEAKER_XX format for the LLM
    speaker_to_label = {}
    for idx, speaker in enumerate(unique_speakers):
        speaker_to_label[speaker] = f'SPEAKER_{str(idx).zfill(2)}'

    # Create temporary transcript with normalized labels
    formatted_lines = []
    for segment in transcription_data:
        original_speaker = segment.get('speaker')
        label = speaker_to_label.get(original_speaker, 'Unknown Speaker')
        sentence = segment.get('sentence', '')
        formatted_lines.append(f"[{label}]: {sentence}")
    formatted_transcription = "\n".join(formatted_lines)

    speaker_labels = list(speaker_to_label.values())

    current_app.logger.info(f"[Auto-Identify] Formatted transcript (first 500 chars): {formatted_transcription[:500]}")
    current_app.logger.info(f"[Auto-Identify] Speaker labels: {speaker_labels}")

    # Apply configurable transcript length limit (-1 means "no limit")
    transcript_limit = SystemSetting.get_setting('transcript_length_limit', 30000)
    if transcript_limit == -1:
        transcript_text = formatted_transcription
    else:
        transcript_text = formatted_transcription[:transcript_limit]

    # Runtime prompt sent to the LLM (French, per product locale) -- this is
    # behavior-bearing text, not a comment; edit with care.
    prompt = f"""Analyse cette transcription de conversation et identifie les noms des locuteurs à partir du contexte et du contenu de leurs dialogues.

Les locuteurs à identifier sont : {', '.join(speaker_labels)}

Indices à chercher :
- Noms mentionnés par d'autres locuteurs quand ils s'adressent à quelqu'un
- Présentations ou références à son propre nom
- Indices contextuels sur les rôles, relations ou postes
- Toute mention directe de noms dans le dialogue

Transcription complète :

{transcript_text}

À partir de cette conversation, identifie les noms les plus probables pour chaque locuteur. Porte une attention particulière à la façon dont les locuteurs s'adressent les uns aux autres.

Réponds avec un seul objet JSON où les clés sont les étiquettes de locuteurs (ex. "SPEAKER_01") et les valeurs sont les noms complets identifiés. Si un nom ne peut pas être déterminé, utilise une chaîne vide "".

Exemple :
{{
"SPEAKER_01": "Marie Lavoie",
"SPEAKER_03": "Jean Tremblay",
"SPEAKER_05": ""
}}

Réponse JSON :
"""

    current_app.logger.info("[Auto-Identify] Calling LLM")

    # Opt-in structured output via env var. NOTE: the comparison is
    # case-sensitive, so 'True'/'YES' do NOT enable schema mode.
    use_schema = os.environ.get('AUTO_IDENTIFY_RESPONSE_SCHEMA', '').strip() in ('1', 'true', 'yes')
    system_msg = (
        "You are an expert in analyzing conversation transcripts to identify speakers "
        "based on contextual clues in the dialogue. Analyze the conversation carefully "
        "to find names mentioned when speakers address each other or introduce themselves. "
        "Your response must be a single, valid JSON object containing only the requested "
        "speaker identifications."
    )

    response_content = None
    if use_schema:
        # Build JSON schema response format with constrained keys
        schema_properties = {label: {"type": "string"} for label in speaker_labels}
        schema_response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "speaker_identification",
                "strict": True,
                "schema": {
                    "type": "object",
                    "properties": schema_properties,
                    "required": speaker_labels,
                    "additionalProperties": False
                }
            }
        }
        schema_prompt = prompt + f"\n\nIMPORTANT: Your JSON response must contain exactly these keys: {', '.join(speaker_labels)}"
        try:
            current_app.logger.info("[Auto-Identify] Trying json_schema response format")
            completion = call_llm_completion(
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": schema_prompt}
                ],
                temperature=0.2,
                response_format=schema_response_format,
                user_id=user_id,
                operation_type='speaker_identification'
            )
            response_content = completion.choices[0].message.content
            current_app.logger.info(f"[Auto-Identify] LLM Raw Response (schema mode): {response_content}")
        except Exception as schema_err:
            # Providers without json_schema support fail here; fall through
            # to the plain completion below.
            current_app.logger.warning(f"[Auto-Identify] json_schema mode failed, falling back to json_object: {schema_err}")
            response_content = None

    if response_content is None:
        # Default path: plain completion without a response schema.
        completion = call_llm_completion(
            messages=[
                {"role": "system", "content": system_msg},
                {"role": "user", "content": prompt}
            ],
            temperature=0.2,
            user_id=user_id,
            operation_type='speaker_identification'
        )
        response_content = completion.choices[0].message.content
        current_app.logger.info(f"[Auto-Identify] LLM Raw Response: {response_content}")

    # Lenient parse -- safe_json_loads presumably returns the {} default
    # on malformed JSON (verify in src.utils).
    identified_map = safe_json_loads(response_content, {})
    current_app.logger.info(f"[Auto-Identify] Parsed identified_map: {identified_map}")

    # --- Sanitize identified_map ---
    identified_map = _sanitize_identified_map(identified_map, speaker_labels)
    current_app.logger.info(f"[Auto-Identify] Sanitized identified_map: {identified_map}")

    # Map back to original speaker labels
    final_speaker_map = {}
    for original_speaker, temp_label in speaker_to_label.items():
        if temp_label in identified_map:
            final_speaker_map[original_speaker] = identified_map[temp_label]

    current_app.logger.info(f"[Auto-Identify] Final speaker_map: {final_speaker_map}")
    return final_speaker_map
|
||||
|
||||
|
||||
def _sanitize_identified_map(identified_map, speaker_labels):
|
||||
"""
|
||||
Clean up LLM output: handle inverted maps, strip commentary,
|
||||
clear placeholders, etc.
|
||||
"""
|
||||
speaker_label_re = re.compile(r'^SPEAKER_\d{2}$')
|
||||
|
||||
# Detect inverted map ({name: "SPEAKER_XX"}) and flip it
|
||||
if identified_map and all(
|
||||
speaker_label_re.match(str(v)) for v in identified_map.values() if v
|
||||
) and not any(speaker_label_re.match(str(k)) for k in identified_map.keys()):
|
||||
current_app.logger.warning("[Auto-Identify] Detected inverted map, flipping keys/values")
|
||||
identified_map = {v: k for k, v in identified_map.items() if v}
|
||||
|
||||
sanitized = {}
|
||||
for speaker_label, identified_name in identified_map.items():
|
||||
# Skip entries whose key isn't a valid SPEAKER_XX label
|
||||
if not speaker_label_re.match(str(speaker_label)):
|
||||
continue
|
||||
if not identified_name or not isinstance(identified_name, str):
|
||||
sanitized[speaker_label] = ""
|
||||
continue
|
||||
|
||||
name = identified_name.strip()
|
||||
|
||||
# Clear generic placeholders
|
||||
if name.lower() in ["unknown", "n/a", "not available", "unclear", "unidentified", ""]:
|
||||
sanitized[speaker_label] = ""
|
||||
continue
|
||||
|
||||
# Clear label-to-label entries (e.g. "SPEAKER_01": "SPEAKER_02")
|
||||
if speaker_label_re.match(name):
|
||||
sanitized[speaker_label] = ""
|
||||
continue
|
||||
|
||||
# Strip parenthetical content: "John (the host)" -> "John"
|
||||
name = re.sub(r'\s*\([^)]*\)', '', name).strip()
|
||||
|
||||
# Take first name segment before comma, semicolon, or slash
|
||||
name = re.split(r'[,;/]', name)[0].strip()
|
||||
|
||||
# Collapse whitespace
|
||||
name = re.sub(r'\s+', ' ', name)
|
||||
|
||||
# Final check: if result still matches SPEAKER_XX, clear it
|
||||
if speaker_label_re.match(name) or not name:
|
||||
sanitized[speaker_label] = ""
|
||||
continue
|
||||
|
||||
sanitized[speaker_label] = name
|
||||
|
||||
return sanitized
|
||||
226
src/services/speaker_merge.py
Normal file
226
src/services/speaker_merge.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
Speaker Merge Service.
|
||||
|
||||
This service handles merging multiple speaker profiles into one.
|
||||
Useful when users accidentally create duplicate speakers for the same person.
|
||||
|
||||
When speakers are merged:
|
||||
- Voice embeddings are combined using weighted average
|
||||
- All snippets are transferred to the target speaker
|
||||
- Usage statistics are combined
|
||||
- Source speakers are deleted
|
||||
- Confidence score is recalculated
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from src.database import db
|
||||
from src.models import Speaker, SpeakerSnippet
|
||||
from src.services.speaker_embedding_matcher import (
|
||||
serialize_embedding,
|
||||
deserialize_embedding,
|
||||
calculate_confidence
|
||||
)
|
||||
|
||||
|
||||
def merge_speakers(target_id, source_ids, user_id):
    """
    Merge multiple speaker profiles into one target speaker.

    All embeddings, snippets, and usage data from source speakers are
    combined into the target speaker. Source speakers are then deleted.
    Commits the session on success.

    Fixes over the previous version:
    - use_count arithmetic now coalesces NULL counters to 0 (matching the
      `or 0` convention used elsewhere in this module) instead of raising
      TypeError.
    - the history-sort fallback no longer uses a bare `except:` (which
      also swallowed SystemExit/KeyboardInterrupt).

    Args:
        target_id: ID of the speaker to keep (receives all merged data)
        source_ids: List of speaker IDs to merge into target
        user_id: ID of the user (for security check)

    Returns:
        Speaker: The updated target speaker

    Raises:
        ValueError: If speakers don't exist, don't belong to the user,
            or target_id appears in source_ids.
    """
    # Validate target speaker (scoped to the user for security)
    target = Speaker.query.filter_by(id=target_id, user_id=user_id).first()
    if not target:
        raise ValueError(f"Target speaker {target_id} not found or doesn't belong to user")

    # Validate source speakers
    sources = Speaker.query.filter(
        Speaker.id.in_(source_ids),
        Speaker.user_id == user_id
    ).all()

    if len(sources) == 0:
        raise ValueError("No valid source speakers found")

    if len(sources) != len(source_ids):
        raise ValueError("Some source speakers don't exist or don't belong to user")

    # Can't merge a speaker with itself
    if target_id in source_ids:
        raise ValueError("Cannot merge a speaker with itself")

    # Combine embeddings (weighted by each profile's sample count)
    _combine_embeddings(target, sources)

    # Transfer snippets to the target speaker
    for source in sources:
        SpeakerSnippet.query.filter_by(speaker_id=source.id).update(
            {'speaker_id': target_id}
        )

    # Combine usage statistics
    for source in sources:
        # Coalesce NULL counters so fresh rows don't crash the merge.
        target.use_count = (target.use_count or 0) + (source.use_count or 0)

        # Update last_used to most recent
        if source.last_used and (not target.last_used or source.last_used > target.last_used):
            target.last_used = source.last_used

        # Combine embedding histories
        if source.embeddings_history:
            target_history = target.embeddings_history or []
            source_history = source.embeddings_history or []
            combined_history = target_history + source_history

            # Sort by timestamp (most recent last) and keep last 10
            try:
                combined_history.sort(key=lambda x: x.get('timestamp', ''))
                target.embeddings_history = combined_history[-10:]
            except (AttributeError, TypeError):
                # Malformed history entries: just concatenate and truncate
                target.embeddings_history = (target_history + source_history)[-10:]

    # Recalculate confidence score
    target.confidence_score = calculate_confidence(target)

    # Delete source speakers
    for source in sources:
        db.session.delete(source)

    # Commit all changes
    db.session.commit()

    return target
|
||||
|
||||
|
||||
def _combine_embeddings(target, sources):
|
||||
"""
|
||||
Combine embeddings from multiple speakers using weighted average.
|
||||
|
||||
Weight is based on embedding_count (more samples = more weight).
|
||||
|
||||
Args:
|
||||
target: Target Speaker instance
|
||||
sources: List of source Speaker instances
|
||||
"""
|
||||
all_embeddings = []
|
||||
all_counts = []
|
||||
|
||||
# Add target's embedding if it exists
|
||||
if target.average_embedding:
|
||||
all_embeddings.append(deserialize_embedding(target.average_embedding))
|
||||
all_counts.append(target.embedding_count or 1)
|
||||
|
||||
# Add all source embeddings
|
||||
for source in sources:
|
||||
if source.average_embedding:
|
||||
all_embeddings.append(deserialize_embedding(source.average_embedding))
|
||||
all_counts.append(source.embedding_count or 1)
|
||||
|
||||
if not all_embeddings:
|
||||
# No embeddings to combine
|
||||
return
|
||||
|
||||
# Calculate weighted average
|
||||
total_count = sum(all_counts)
|
||||
weights = [c / total_count for c in all_counts]
|
||||
|
||||
combined_emb = np.average(all_embeddings, axis=0, weights=weights)
|
||||
|
||||
# Update target
|
||||
target.average_embedding = serialize_embedding(combined_emb)
|
||||
target.embedding_count = total_count
|
||||
|
||||
|
||||
def preview_merge(target_id, source_ids, user_id):
    """
    Preview what a merge would look like without executing it.

    Fix over the previous version: use_count is coalesced to 0 when NULL
    (consistent with the `or 0` convention elsewhere in this module), so
    fresh speaker rows no longer raise TypeError during the summation.

    Args:
        target_id: ID of the target speaker
        source_ids: List of source speaker IDs
        user_id: ID of the user

    Returns:
        dict: Preview of the merge results
        {
            'target_name': '...',
            'source_names': [...],
            'combined_use_count': 123,
            'combined_embedding_count': 45,
            'total_snippets': 67,
            'has_embeddings': bool
        }

    Raises:
        ValueError: If the target or sources are missing or not owned
            by the user.
    """
    # Validate speakers
    target = Speaker.query.filter_by(id=target_id, user_id=user_id).first()
    if not target:
        raise ValueError("Target speaker not found")

    sources = Speaker.query.filter(
        Speaker.id.in_(source_ids),
        Speaker.user_id == user_id
    ).all()

    if len(sources) == 0:
        raise ValueError("No valid source speakers found")

    # Calculate combined statistics (NULL counters count as 0)
    combined_use_count = target.use_count or 0
    combined_embedding_count = target.embedding_count or 0
    total_snippets = SpeakerSnippet.query.filter_by(speaker_id=target_id).count()

    source_names = []
    for source in sources:
        combined_use_count += (source.use_count or 0)
        combined_embedding_count += (source.embedding_count or 0)
        total_snippets += SpeakerSnippet.query.filter_by(speaker_id=source.id).count()
        source_names.append(source.name)

    return {
        'target_name': target.name,
        'source_names': source_names,
        'combined_use_count': combined_use_count,
        'combined_embedding_count': combined_embedding_count,
        'total_snippets': total_snippets,
        'has_embeddings': target.average_embedding is not None or any(s.average_embedding for s in sources)
    }
|
||||
|
||||
|
||||
def can_merge_speakers(speaker_ids, user_id):
    """
    Check whether a set of speakers is eligible for merging.

    Requires at least two distinct IDs, all owned by the given user.

    Args:
        speaker_ids: List of speaker IDs
        user_id: ID of the user

    Returns:
        tuple: (can_merge: bool, error_message: str) -- the message is
        "" when the merge is allowed.
    """
    if len(speaker_ids) < 2:
        return False, "Need at least 2 speakers to merge"

    if len(set(speaker_ids)) != len(speaker_ids):
        return False, "Duplicate speaker IDs provided"

    owned = Speaker.query.filter(
        Speaker.id.in_(speaker_ids),
        Speaker.user_id == user_id
    ).all()

    if len(owned) != len(speaker_ids):
        return False, "Some speakers don't exist or don't belong to user"

    return True, ""
|
||||
371
src/services/speaker_snippets.py
Normal file
371
src/services/speaker_snippets.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""
|
||||
Speaker Snippets Service.
|
||||
|
||||
This service handles the extraction and management of representative speech snippets
|
||||
from recordings. Snippets provide context when viewing speaker profiles and help
|
||||
users verify speaker identifications.
|
||||
|
||||
Key functions:
|
||||
- Extract snippets when speakers are identified in recordings
|
||||
- Retrieve snippets for display in speaker profiles
|
||||
- Clean up old snippets to prevent database bloat
|
||||
"""
|
||||
|
||||
import json
|
||||
from src.database import db
|
||||
from src.models import Speaker, SpeakerSnippet, Recording
|
||||
|
||||
MAX_SNIPPETS_PER_SPEAKER = 7
|
||||
MAX_SNIPPETS_PER_RECORDING = 2
|
||||
|
||||
|
||||
def create_speaker_snippets(recording_id, speaker_map):
    """
    Extract and store representative snippets for each identified speaker.

    This function is called after a user saves speaker identifications in a
    recording. It extracts up to MAX_SNIPPETS_PER_RECORDING quotes per speaker
    from this recording, and enforces a global cap of MAX_SNIPPETS_PER_SPEAKER
    by evicting the oldest snippets first.

    Args:
        recording_id: ID of the recording
        speaker_map: Dict mapping SPEAKER_XX to speaker info
                     {'SPEAKER_00': {'name': 'John Doe', 'isMe': False}, ...}

    Returns:
        int: Number of snippets created
    """
    recording = Recording.query.get(recording_id)
    if not recording or not recording.transcription:
        return 0

    try:
        transcript = json.loads(recording.transcription)
    except (json.JSONDecodeError, TypeError):
        return 0

    # Build a reverse map: assigned name -> speaker_info.
    # After the transcript is saved, segment['speaker'] contains the real
    # name, not the original SPEAKER_XX label, so we must match by name too.
    name_to_info = {}
    for label, info in speaker_map.items():
        name = info.get('name', '').strip()
        if name and not name.startswith('SPEAKER_'):
            name_to_info[name] = info

    # Collect candidates per speaker: speaker_id -> [(segment_idx, text, timestamp)]
    candidates = {}
    # Cache Speaker lookups by name. The previous implementation issued one
    # Speaker query per transcript segment even though a given name always
    # resolves to the same row; one query per distinct name is enough.
    speaker_cache = {}

    for segment_idx, segment in enumerate(transcript):
        speaker_field = segment.get('speaker')

        if not speaker_field:
            continue

        # Try matching by original label first, then by assigned name
        if speaker_field in speaker_map:
            speaker_name = speaker_map[speaker_field].get('name')
        elif speaker_field in name_to_info:
            speaker_name = speaker_field
        else:
            continue

        if not speaker_name or speaker_name.startswith('SPEAKER_'):
            continue

        # Find the speaker in the database (memoized per name)
        if speaker_name not in speaker_cache:
            speaker_cache[speaker_name] = Speaker.query.filter_by(
                user_id=recording.user_id,
                name=speaker_name
            ).first()
        speaker = speaker_cache[speaker_name]

        if not speaker:
            continue

        text = segment.get('sentence', '').strip()
        # Very short utterances make poor context snippets
        if len(text) < 10:
            continue

        candidates.setdefault(speaker.id, []).append(
            (segment_idx, text[:200], segment.get('start_time'))
        )

    # Delete existing snippets for this recording (re-save replaces them)
    SpeakerSnippet.query.filter_by(recording_id=recording_id).delete()

    snippets_created = 0

    for speaker_id, segs in candidates.items():
        # Pick up to MAX_SNIPPETS_PER_RECORDING spread across the transcript
        if len(segs) <= MAX_SNIPPETS_PER_RECORDING:
            chosen = segs
        else:
            # Evenly sample from the segments
            step = len(segs) / MAX_SNIPPETS_PER_RECORDING
            chosen = [segs[int(i * step)] for i in range(MAX_SNIPPETS_PER_RECORDING)]

        for segment_idx, text_snippet, timestamp in chosen:
            # Evict oldest if at global cap
            global_count = SpeakerSnippet.query.filter_by(speaker_id=speaker_id).count()
            if global_count >= MAX_SNIPPETS_PER_SPEAKER:
                oldest = SpeakerSnippet.query.filter_by(speaker_id=speaker_id)\
                    .order_by(SpeakerSnippet.created_at.asc()).first()
                if oldest:
                    db.session.delete(oldest)
                    db.session.flush()

            snippet = SpeakerSnippet(
                speaker_id=speaker_id,
                recording_id=recording_id,
                segment_index=segment_idx,
                text_snippet=text_snippet,
                timestamp=timestamp,
            )
            db.session.add(snippet)
            snippets_created += 1

        # Flush after each speaker batch to keep counts accurate
        db.session.flush()

    if snippets_created > 0:
        db.session.commit()

    return snippets_created
|
||||
|
||||
|
||||
def _generate_dynamic_snippets(speaker_id, limit=3):
    """
    Dynamically generate audio snippets from a speaker's recent recordings.

    This function finds short audio segments (3-4 seconds) from recent recordings
    where the speaker appears. These can be played back to verify speaker identity.
    Nothing is persisted; snippets are recomputed on every call.

    Args:
        speaker_id: ID of the speaker
        limit: Maximum number of snippets to return (default 3)

    Returns:
        list: List of snippet dictionaries with audio segment information
        [{'recording_id': 123, 'start_time': 45.2, 'duration': 3.5, ...}, ...]
    """
    # Get the speaker
    speaker = Speaker.query.get(speaker_id)
    if not speaker:
        return []

    # Find recordings that have this speaker's name in transcription.
    # We'll look at the last 10 recordings and extract snippets from them.
    # Recordings whose audio was purged (audio_deleted_at set) are skipped
    # since their snippets could not be played back.
    recordings = Recording.query.filter_by(user_id=speaker.user_id)\
        .filter(Recording.transcription.isnot(None))\
        .filter(Recording.transcription != '')\
        .filter(Recording.audio_deleted_at.is_(None))\
        .order_by(Recording.created_at.desc())\
        .limit(10).all()

    snippets = []

    for recording in recordings:
        if len(snippets) >= limit:
            break

        try:
            # Parse transcription JSON
            transcript = json.loads(recording.transcription)

            # Plain-text (non-diarized) transcriptions are not a list — skip
            if not isinstance(transcript, list):
                continue

            # Find segments where this speaker appears
            speaker_segments = []
            for idx, segment in enumerate(transcript):
                # Check if segment has speaker identification matching our speaker's name
                speaker_label = segment.get('speaker')

                # In identified transcripts, the speaker field contains the actual name
                if speaker_label != speaker.name:
                    continue

                start_time = segment.get('start_time')
                end_time = segment.get('end_time')

                if start_time is None or end_time is None:
                    continue

                duration = end_time - start_time

                # Skip very short segments (less than 2 seconds)
                if duration < 2.0:
                    continue

                speaker_segments.append({
                    'index': idx,
                    'start_time': start_time,
                    'end_time': end_time,
                    'duration': duration,
                    'text': segment.get('sentence', '').strip()[:100]  # Preview text
                })

            if not speaker_segments:
                continue

            # Take snippets from middle portions (skip first and last 10%)
            # — greetings and sign-offs are less representative of a voice.
            total_segments = len(speaker_segments)
            if total_segments > 4:
                # Skip first and last 10%
                start_idx = max(1, int(total_segments * 0.1))
                end_idx = min(total_segments - 1, int(total_segments * 0.9))
                middle_segments = speaker_segments[start_idx:end_idx]
            else:
                middle_segments = speaker_segments

            # Take 1 snippet per recording from the middle
            if middle_segments:
                # Pick a segment from the middle
                middle_idx = len(middle_segments) // 2
                segment = middle_segments[middle_idx]

                # Limit audio snippet to 3-4 seconds
                snippet_duration = min(4.0, segment['duration'])

                snippets.append({
                    'id': None,  # Dynamic snippet, no database ID
                    'speaker_id': speaker_id,
                    'recording_id': recording.id,
                    'start_time': segment['start_time'],
                    'duration': snippet_duration,
                    'text': segment['text'],  # Preview text for context
                    'recording_title': recording.title or 'Untitled Recording',
                    'created_at': recording.created_at.isoformat() if recording.created_at else None
                })

        except (json.JSONDecodeError, TypeError, KeyError) as e:
            # Skip recordings with invalid transcription format
            continue

    return snippets
|
||||
|
||||
|
||||
def get_speaker_snippets(speaker_id, limit=3):
    """
    Return recent audio snippets for a speaker.

    Snippets are short (3-4 second) audio segments taken from recent recordings
    where the speaker appears, intended for playback when verifying identity.
    They are computed fresh on each call rather than read from storage.

    Args:
        speaker_id: ID of the speaker
        limit: Maximum number of snippets to return (default 3)

    Returns:
        list: Snippet dictionaries describing playable audio segments, e.g.
        [{'recording_id': 123, 'start_time': 45.2, 'duration': 3.5, ...}, ...]
    """
    return _generate_dynamic_snippets(speaker_id, limit)
|
||||
|
||||
|
||||
def get_snippets_by_recording(recording_id, speaker_id):
    """
    Fetch every stored snippet for one speaker within one recording.

    Args:
        recording_id: ID of the recording
        speaker_id: ID of the speaker

    Returns:
        list: Snippet dictionaries, ordered by their position in the transcript
    """
    query = (
        SpeakerSnippet.query
        .filter_by(speaker_id=speaker_id, recording_id=recording_id)
        .order_by(SpeakerSnippet.segment_index)
    )
    return [row.to_dict() for row in query.all()]
|
||||
|
||||
|
||||
def cleanup_old_snippets(speaker_id, keep=10):
    """
    Trim a speaker's snippets down to the *keep* most recent ones.

    Args:
        speaker_id: ID of the speaker
        keep: Number of snippets to keep (default 10)

    Returns:
        int: Number of snippets deleted
    """
    # Newest first, so everything past index `keep` is stale.
    ordered = SpeakerSnippet.query.filter_by(speaker_id=speaker_id)\
        .order_by(SpeakerSnippet.created_at.desc()).all()

    stale = ordered[keep:]
    if not stale:
        return 0

    for snippet in stale:
        db.session.delete(snippet)
    db.session.commit()

    return len(stale)
|
||||
|
||||
|
||||
def delete_snippets_for_recording(recording_id):
    """
    Remove every snippet tied to a recording.

    Typically invoked when a recording is deleted or reprocessed.

    Args:
        recording_id: ID of the recording

    Returns:
        int: Number of snippets deleted
    """
    # Bulk delete at the query level; returns the affected row count.
    removed = SpeakerSnippet.query.filter_by(recording_id=recording_id).delete()
    db.session.commit()
    return removed
|
||||
|
||||
|
||||
def get_speaker_recordings_with_snippets(speaker_id):
    """
    List the recordings that hold stored snippets for this speaker.

    Args:
        speaker_id: ID of the speaker

    Returns:
        list: Recording dictionaries with per-recording snippet counts, e.g.
        [{'id': 123, 'title': '...', 'snippet_count': 3, 'created_at': '...'}, ...]
    """
    from sqlalchemy import func

    # Aggregate snippet rows per recording, newest recordings first.
    snippet_count = func.count(SpeakerSnippet.id).label('snippet_count')
    rows = (
        db.session.query(
            Recording.id,
            Recording.title,
            Recording.created_at,
            snippet_count,
        )
        .join(SpeakerSnippet, Recording.id == SpeakerSnippet.recording_id)
        .filter(SpeakerSnippet.speaker_id == speaker_id)
        .group_by(Recording.id)
        .order_by(Recording.created_at.desc())
        .all()
    )

    results = []
    for row in rows:
        results.append({
            'id': row.id,
            'title': row.title,
            'snippet_count': row.snippet_count,
            'created_at': row.created_at.isoformat() if row.created_at else None,
        })
    return results
|
||||
270
src/services/token_tracking.py
Normal file
270
src/services/token_tracking.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Token usage tracking service for monitoring LLM API consumption and budget enforcement.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Tuple, Optional, Dict, List
|
||||
|
||||
from sqlalchemy import func, extract
|
||||
|
||||
from src.database import db
|
||||
from src.models.token_usage import TokenUsage
|
||||
from src.models.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenTracker:
    """Service for recording and checking token usage.

    Usage is aggregated into one TokenUsage row per (user, day, operation).
    Budget checks compare a user's month-to-date total against
    ``User.monthly_token_budget`` and fail open on errors.
    """

    # Known operation categories; rows are keyed by one of these strings.
    OPERATION_TYPES = [
        'summarization',
        'chat',
        'title_generation',
        'event_extraction',
        'query_routing',
        'query_enrichment'
    ]

    def record_usage(
        self,
        user_id: int,
        operation_type: str,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        model_name: str = None,
        cost: float = None
    ):
        """
        Record token usage - upserts into daily aggregate.

        Commits on success; rolls back and returns None on any failure
        (recording usage must never break the calling request).

        Args:
            user_id: User ID who made the request
            operation_type: Type of operation (summarization, chat, etc.)
            prompt_tokens: Number of input tokens
            completion_tokens: Number of output tokens
            total_tokens: Total tokens (prompt + completion)
            model_name: Name of the model used
            cost: API cost if available (e.g., from OpenRouter)

        Returns:
            The updated/created TokenUsage row, or None on error.
        """
        try:
            today = date.today()

            # Find or create today's record for this user/operation
            usage = TokenUsage.query.filter_by(
                user_id=user_id,
                date=today,
                operation_type=operation_type
            ).first()

            if usage:
                # Update existing record
                usage.prompt_tokens += prompt_tokens
                usage.completion_tokens += completion_tokens
                usage.total_tokens += total_tokens
                usage.request_count += 1
                if cost:
                    usage.cost += cost
                if model_name:
                    usage.model_name = model_name  # Update to latest model used
            else:
                # Create new record
                usage = TokenUsage(
                    user_id=user_id,
                    date=today,
                    operation_type=operation_type,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                    request_count=1,
                    model_name=model_name,
                    cost=cost or 0.0
                )
                db.session.add(usage)

            db.session.commit()
            logger.debug(f"Recorded {total_tokens} tokens for user {user_id}, operation {operation_type}")
            return usage

        except Exception as e:
            logger.error(f"Failed to record token usage: {e}")
            db.session.rollback()
            return None

    def get_monthly_usage(self, user_id: int, year: int = None, month: int = None) -> int:
        """Get total tokens used by a user in a given month.

        Defaults to the current calendar month; returns 0 when no rows exist.
        """
        if year is None:
            year = date.today().year
        if month is None:
            month = date.today().month

        result = db.session.query(func.sum(TokenUsage.total_tokens)).filter(
            TokenUsage.user_id == user_id,
            extract('year', TokenUsage.date) == year,
            extract('month', TokenUsage.date) == month
        ).scalar()

        # SUM over zero rows yields NULL/None
        return result or 0

    def get_monthly_cost(self, user_id: int, year: int = None, month: int = None) -> float:
        """Get total cost for a user in a given month.

        Defaults to the current calendar month; returns 0.0 when no rows exist.
        """
        if year is None:
            year = date.today().year
        if month is None:
            month = date.today().month

        result = db.session.query(func.sum(TokenUsage.cost)).filter(
            TokenUsage.user_id == user_id,
            extract('year', TokenUsage.date) == year,
            extract('month', TokenUsage.date) == month
        ).scalar()

        return result or 0.0

    def check_budget(self, user_id: int) -> Tuple[bool, float, Optional[str]]:
        """
        Check if user is within budget.

        Returns:
            (can_proceed, usage_percentage, message)
            - can_proceed: False if hard cap (100%) reached
            - usage_percentage: 0-100+
            - message: Warning/error message if applicable
        """
        try:
            user = db.session.get(User, user_id)
            # A missing user or a zero/None budget means no enforcement.
            if not user or not user.monthly_token_budget:
                return (True, 0, None)  # No budget = unlimited

            current_usage = self.get_monthly_usage(user_id)
            budget = user.monthly_token_budget
            percentage = (current_usage / budget) * 100

            if percentage >= 100:
                return (False, percentage,
                        f"Monthly token budget exceeded ({percentage:.1f}%). Contact admin for more tokens.")
            elif percentage >= 80:
                # Soft warning threshold — request still proceeds.
                return (True, percentage,
                        f"Warning: {percentage:.1f}% of monthly token budget used.")
            else:
                return (True, percentage, None)

        except Exception as e:
            logger.error(f"Failed to check budget for user {user_id}: {e}")
            # Fail open - allow the request if we can't check
            return (True, 0, None)

    def get_daily_stats(self, days: int = 30, user_id: int = None) -> List[Dict]:
        """Get daily token usage for charts.

        Returns one entry per calendar day over the window (missing days are
        zero-filled), each with a total, cost, and per-operation breakdown.
        """
        start_date = date.today() - timedelta(days=days - 1)

        query = db.session.query(
            TokenUsage.date,
            TokenUsage.operation_type,
            func.sum(TokenUsage.total_tokens).label('tokens'),
            func.sum(TokenUsage.cost).label('cost')
        ).filter(TokenUsage.date >= start_date)

        if user_id:
            query = query.filter(TokenUsage.user_id == user_id)

        results = query.group_by(TokenUsage.date, TokenUsage.operation_type).all()

        # Organize by date
        stats = {}
        for r in results:
            date_str = r.date.isoformat()
            if date_str not in stats:
                stats[date_str] = {'date': date_str, 'total': 0, 'cost': 0.0, 'by_operation': {}}
            stats[date_str]['total'] += r.tokens or 0
            stats[date_str]['cost'] += r.cost or 0
            stats[date_str]['by_operation'][r.operation_type] = r.tokens or 0

        # Fill in missing dates with zeros
        all_dates = []
        current = start_date
        while current <= date.today():
            date_str = current.isoformat()
            if date_str not in stats:
                stats[date_str] = {'date': date_str, 'total': 0, 'cost': 0.0, 'by_operation': {}}
            all_dates.append(date_str)
            current += timedelta(days=1)

        return [stats[d] for d in sorted(all_dates)]

    def get_monthly_stats(self, months: int = 12) -> List[Dict]:
        """Get monthly token usage for charts.

        Aggregates across all users; returns at most the last *months* entries.
        """
        results = db.session.query(
            extract('year', TokenUsage.date).label('year'),
            extract('month', TokenUsage.date).label('month'),
            func.sum(TokenUsage.total_tokens).label('tokens'),
            func.sum(TokenUsage.cost).label('cost')
        ).group_by('year', 'month').order_by('year', 'month').all()

        # Get last N months
        monthly_data = [
            {
                'year': int(r.year),
                'month': int(r.month),
                'tokens': r.tokens or 0,
                'cost': r.cost or 0
            }
            for r in results
        ]

        return monthly_data[-months:] if len(monthly_data) > months else monthly_data

    def get_user_stats(self) -> List[Dict]:
        """Get per-user token usage breakdown for current month.

        Uses an outer join so users with zero usage this month still appear
        (with usage/cost of 0).
        """
        today = date.today()

        results = db.session.query(
            User.id,
            User.username,
            User.monthly_token_budget,
            func.sum(TokenUsage.total_tokens).label('usage'),
            func.sum(TokenUsage.cost).label('cost')
        ).outerjoin(
            TokenUsage,
            (User.id == TokenUsage.user_id) &
            (extract('year', TokenUsage.date) == today.year) &
            (extract('month', TokenUsage.date) == today.month)
        ).group_by(User.id).all()

        return [
            {
                'user_id': r.id,
                'username': r.username,
                'monthly_budget': r.monthly_token_budget,
                'current_usage': r.usage or 0,
                'cost': r.cost or 0,
                'percentage': ((r.usage or 0) / r.monthly_token_budget * 100)
                    if r.monthly_token_budget else 0
            }
            for r in results
        ]

    def get_today_usage(self, user_id: int = None) -> Dict:
        """Get today's token usage.

        Optionally scoped to a single user; returns zeros when no rows exist.
        """
        today = date.today()

        query = db.session.query(
            func.sum(TokenUsage.total_tokens).label('tokens'),
            func.sum(TokenUsage.cost).label('cost')
        ).filter(TokenUsage.date == today)

        if user_id:
            query = query.filter(TokenUsage.user_id == user_id)

        result = query.first()

        return {
            'tokens': result.tokens or 0,
            'cost': result.cost or 0
        }
|
||||
|
||||
|
||||
# Module-level singleton: import `token_tracker` instead of constructing
# TokenTracker so all callers share one stateless service instance.
token_tracker = TokenTracker()
|
||||
98
src/services/transcription/__init__.py
Normal file
98
src/services/transcription/__init__.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Transcription service package.
|
||||
|
||||
Provides a connector-based architecture for speech-to-text transcription
|
||||
with support for multiple providers:
|
||||
|
||||
- OpenAI Whisper (whisper-1)
|
||||
- OpenAI GPT-4o Transcribe (gpt-4o-transcribe, gpt-4o-mini-transcribe, gpt-4o-transcribe-diarize)
|
||||
- Custom ASR endpoints (whisper-asr-webservice, WhisperX, etc.)
|
||||
|
||||
Usage:
|
||||
from src.services.transcription import (
|
||||
transcribe,
|
||||
get_connector,
|
||||
supports_diarization,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
)
|
||||
|
||||
# Simple transcription using active connector
|
||||
with open('audio.mp3', 'rb') as f:
|
||||
request = TranscriptionRequest(
|
||||
audio_file=f,
|
||||
filename='audio.mp3',
|
||||
diarize=True
|
||||
)
|
||||
response = transcribe(request)
|
||||
print(response.text)
|
||||
if response.segments:
|
||||
for seg in response.segments:
|
||||
print(f"[{seg.speaker}]: {seg.text}")
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
TranscriptionCapability,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionSegment,
|
||||
BaseTranscriptionConnector,
|
||||
ConnectorSpecifications,
|
||||
DEFAULT_SPECIFICATIONS,
|
||||
)
|
||||
|
||||
from .exceptions import (
|
||||
TranscriptionError,
|
||||
ConfigurationError,
|
||||
ProviderError,
|
||||
AudioFormatError,
|
||||
ChunkingError,
|
||||
)
|
||||
|
||||
from .registry import (
|
||||
ConnectorRegistry,
|
||||
get_registry,
|
||||
connector_registry,
|
||||
transcribe,
|
||||
get_connector,
|
||||
supports_diarization,
|
||||
)
|
||||
|
||||
from .connectors import (
|
||||
OpenAIWhisperConnector,
|
||||
OpenAITranscribeConnector,
|
||||
ASREndpointConnector,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base types
|
||||
'TranscriptionCapability',
|
||||
'TranscriptionRequest',
|
||||
'TranscriptionResponse',
|
||||
'TranscriptionSegment',
|
||||
'BaseTranscriptionConnector',
|
||||
'ConnectorSpecifications',
|
||||
'DEFAULT_SPECIFICATIONS',
|
||||
|
||||
# Exceptions
|
||||
'TranscriptionError',
|
||||
'ConfigurationError',
|
||||
'ProviderError',
|
||||
'AudioFormatError',
|
||||
'ChunkingError',
|
||||
|
||||
# Registry
|
||||
'ConnectorRegistry',
|
||||
'get_registry',
|
||||
'connector_registry',
|
||||
|
||||
# Convenience functions
|
||||
'transcribe',
|
||||
'get_connector',
|
||||
'supports_diarization',
|
||||
|
||||
# Connectors
|
||||
'OpenAIWhisperConnector',
|
||||
'OpenAITranscribeConnector',
|
||||
'ASREndpointConnector',
|
||||
]
|
||||
243
src/services/transcription/base.py
Normal file
243
src/services/transcription/base.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
Base classes and data types for transcription connectors.
|
||||
"""
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, List, Dict, Any, BinaryIO, Set, Type, FrozenSet
|
||||
|
||||
|
||||
class TranscriptionCapability(Enum):
    """Capabilities that connectors can declare support for.

    Connectors advertise a set of these in their ``CAPABILITIES`` class
    attribute; callers query them via ``BaseTranscriptionConnector.supports``.
    """
    DIARIZATION = auto()  # Speaker diarization
    CHUNKING = auto()  # Automatic file chunking for large files
    TIMESTAMPS = auto()  # Word/segment timestamps
    LANGUAGE_DETECTION = auto()  # Auto language detection
    KNOWN_SPEAKERS = auto()  # Support for known speaker references (future)
    SPEAKER_EMBEDDINGS = auto()  # Return speaker embeddings
    SPEAKER_COUNT_CONTROL = auto()  # Support for min/max speaker count parameters
    STREAMING = auto()  # Real-time streaming transcription
|
||||
|
||||
|
||||
@dataclass
class ConnectorSpecifications:
    """
    Provider-specific constraints and requirements.

    Each connector declares its constraints so the application can automatically
    handle chunking, format conversion, and other preprocessing as needed.
    """
    # Size constraints
    max_file_size_bytes: Optional[int] = None  # None = unlimited

    # Duration constraints
    max_duration_seconds: Optional[int] = None  # None = unlimited
    min_duration_for_chunking: Optional[int] = None  # Provider's internal chunking threshold

    # Chunking behavior
    handles_chunking_internally: bool = False  # Provider handles large files
    requires_chunking_param: bool = False  # Must send chunking_strategy param
    recommended_chunk_seconds: int = 600  # 10 minutes default

    # Audio format support - connector-specific codec restrictions
    # None = use system defaults from get_supported_codecs()
    # Set = only allow these codecs (overrides defaults)
    supported_codecs: Optional[FrozenSet[str]] = None
    # Codecs this connector doesn't support (removed from defaults)
    # Merged with AUDIO_UNSUPPORTED_CODECS env var
    unsupported_codecs: Optional[FrozenSet[str]] = None
|
||||
|
||||
|
||||
# Default specifications for connectors that don't define their own
# (all-defaults instance: no size/duration limits, no chunking hints).
DEFAULT_SPECIFICATIONS = ConnectorSpecifications()
|
||||
|
||||
|
||||
@dataclass
class TranscriptionRequest:
    """Standardized transcription request.

    Built by callers and handed to a connector's ``transcribe`` method;
    connectors may ignore options they do not support.
    """
    audio_file: BinaryIO  # Open binary stream positioned at the audio data
    filename: str
    mime_type: Optional[str] = None
    language: Optional[str] = None  # None = let the provider detect

    # Diarization options
    diarize: bool = False
    min_speakers: Optional[int] = None
    max_speakers: Optional[int] = None
    known_speaker_names: Optional[List[str]] = None
    # known_speaker_references: Dict mapping speaker label to either BinaryIO or data URL string
    known_speaker_references: Optional[Dict[str, Any]] = None

    # Advanced options
    prompt: Optional[str] = None
    hotwords: Optional[str] = None  # Comma-separated words to bias recognition
    temperature: Optional[float] = None

    # Provider-specific options (passthrough)
    extra_options: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class TranscriptionSegment:
    """Single segment of transcription with optional metadata.

    Only ``text`` is guaranteed; all other fields depend on the provider's
    capabilities (diarization, timestamps, word-level detail).
    """
    text: str
    speaker: Optional[str] = None  # Speaker label/name when diarized
    start_time: Optional[float] = None  # Seconds from start of audio
    end_time: Optional[float] = None  # Seconds from start of audio
    confidence: Optional[float] = None
    words: Optional[List[Dict[str, Any]]] = None  # Word-level detail, provider-specific shape
|
||||
|
||||
|
||||
@dataclass
class TranscriptionResponse:
    """Standardized transcription response.

    Returned by every connector regardless of provider; optional fields are
    populated only when the provider supplies the corresponding data.
    """
    # Core content
    text: str  # Plain text transcription
    segments: Optional[List[TranscriptionSegment]] = None  # Detailed segments

    # Metadata
    language: Optional[str] = None  # Detected language
    duration: Optional[float] = None  # Audio duration in seconds

    # Speaker information
    speakers: Optional[List[str]] = None  # List of speakers found
    speaker_embeddings: Optional[Dict[str, List[float]]] = None

    # Provider info
    provider: str = ""
    model: str = ""

    # Raw response for debugging
    raw_response: Optional[Dict[str, Any]] = None

    def to_storage_format(self) -> str:
        """
        Serialize this response into the database storage format.

        When segments exist, the result is a JSON array of objects in the
        shape the rest of the codebase expects:

            [
                {
                    "speaker": "SPEAKER_00",
                    "sentence": "Text here",
                    "start_time": 0.0,
                    "end_time": 5.5
                },
                ...
            ]

        Without segments (non-diarized transcriptions), the plain text is
        stored as-is.
        """
        if not self.segments:
            return self.text
        rows = []
        for seg in self.segments:
            rows.append({
                'speaker': seg.speaker or 'Unknown Speaker',
                'sentence': seg.text,
                'start_time': seg.start_time,
                'end_time': seg.end_time
            })
        return json.dumps(rows)

    def has_diarization(self) -> bool:
        """Check if this response contains diarization data."""
        # True only when at least one segment carries a speaker label.
        return bool(self.segments) and any(seg.speaker for seg in self.segments)
|
||||
|
||||
|
||||
class BaseTranscriptionConnector(ABC):
    """Abstract base class for transcription connectors.

    Subclasses implement ``_validate_config`` and ``transcribe``, and
    override the class-level ``CAPABILITIES``, ``PROVIDER_NAME`` and
    ``SPECIFICATIONS`` declarations to describe what they support.
    """

    # Class-level capability declarations - subclasses should override
    CAPABILITIES: Set[TranscriptionCapability] = set()
    PROVIDER_NAME: str = "unknown"

    # Provider-specific constraints - subclasses should override
    SPECIFICATIONS: ConnectorSpecifications = DEFAULT_SPECIFICATIONS

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize connector with configuration.

        Args:
            config: Provider-specific configuration dict

        Raises:
            ConfigurationError: Propagated from ``_validate_config`` when the
                supplied configuration is missing or invalid.
        """
        self.config = config
        # Validate eagerly so misconfiguration fails at construction time.
        self._validate_config()

    @abstractmethod
    def _validate_config(self) -> None:
        """
        Validate required configuration is present.

        Raises:
            ConfigurationError: If required config is missing or invalid
        """
        pass

    @abstractmethod
    def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
        """
        Perform transcription.

        Args:
            request: Standardized transcription request

        Returns:
            Standardized transcription response

        Raises:
            TranscriptionError: On transcription failure
            ConfigurationError: On configuration issues
        """
        pass

    def supports(self, capability: TranscriptionCapability) -> bool:
        """Check if connector supports a capability."""
        return capability in self.CAPABILITIES

    def get_capabilities(self) -> Set[TranscriptionCapability]:
        """Get all supported capabilities (as a defensive copy)."""
        return self.CAPABILITIES.copy()

    @property
    def supports_diarization(self) -> bool:
        """Check if connector supports speaker diarization."""
        return TranscriptionCapability.DIARIZATION in self.CAPABILITIES

    @property
    def supports_chunking(self) -> bool:
        """Check if connector supports automatic file chunking."""
        return TranscriptionCapability.CHUNKING in self.CAPABILITIES

    @property
    def supports_speaker_count_control(self) -> bool:
        """Check if connector supports min/max speaker count parameters."""
        return TranscriptionCapability.SPEAKER_COUNT_CONTROL in self.CAPABILITIES

    @property
    def specifications(self) -> ConnectorSpecifications:
        """Get connector specifications."""
        return self.SPECIFICATIONS

    @classmethod
    def get_config_schema(cls) -> Dict[str, Any]:
        """
        Return JSON schema for this connector's configuration.
        Useful for admin UI and validation.

        Returns:
            JSON schema dict describing required and optional config
            (empty dict by default; subclasses may override).
        """
        return {}

    def health_check(self) -> bool:
        """
        Check if the connector is properly configured and reachable.

        Base implementation optimistically returns True; subclasses that can
        probe their backend should override this.

        Returns:
            True if the connector is healthy, False otherwise
        """
        return True
|
||||
15
src/services/transcription/connectors/__init__.py
Normal file
15
src/services/transcription/connectors/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Transcription connector implementations.
|
||||
"""
|
||||
|
||||
from .openai_whisper import OpenAIWhisperConnector
|
||||
from .openai_transcribe import OpenAITranscribeConnector
|
||||
from .asr_endpoint import ASREndpointConnector
|
||||
from .azure_openai_transcribe import AzureOpenAITranscribeConnector
|
||||
|
||||
__all__ = [
|
||||
'OpenAIWhisperConnector',
|
||||
'OpenAITranscribeConnector',
|
||||
'ASREndpointConnector',
|
||||
'AzureOpenAITranscribeConnector',
|
||||
]
|
||||
337
src/services/transcription/connectors/asr_endpoint.py
Normal file
337
src/services/transcription/connectors/asr_endpoint.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
ASR Endpoint connector for custom self-hosted ASR services.
|
||||
|
||||
Supports whisper-asr-webservice, WhisperX, and other compatible ASR services
|
||||
that expose a /asr endpoint.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import httpx
|
||||
from typing import Dict, Any, Set, Optional
|
||||
|
||||
from ..base import (
|
||||
BaseTranscriptionConnector,
|
||||
TranscriptionCapability,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionSegment,
|
||||
ConnectorSpecifications,
|
||||
)
|
||||
from ..exceptions import TranscriptionError, ConfigurationError, ProviderError
|
||||
from src.config.app_config import ASR_ENABLE_CHUNKING, ASR_MAX_DURATION_SECONDS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ASREndpointConnector(BaseTranscriptionConnector):
    """Connector for custom ASR webservice (whisper-asr-webservice, WhisperX, etc.)."""

    CAPABILITIES: Set[TranscriptionCapability] = {
        TranscriptionCapability.DIARIZATION,
        TranscriptionCapability.TIMESTAMPS,
        TranscriptionCapability.LANGUAGE_DETECTION,
        TranscriptionCapability.SPEAKER_COUNT_CONTROL,  # Supports min/max speakers
    }
    PROVIDER_NAME = "asr_endpoint"

    # SPECIFICATIONS is set dynamically in __init__ based on ASR_ENABLE_CHUNKING config
    # Default values here for class-level reference (overridden per-instance)
    SPECIFICATIONS = ConnectorSpecifications(
        max_file_size_bytes=None,
        max_duration_seconds=None,
        handles_chunking_internally=True,
    )

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the ASR Endpoint connector.

        Args:
            config: Configuration dict with keys:
                - base_url: ASR service base URL (required)
                - timeout: Request timeout in seconds (default: 1800)
                - return_speaker_embeddings: Whether to request embeddings (default: False)
                - diarize: Whether to enable diarization by default (default: True)
        """
        super().__init__(config)

        self.base_url = config['base_url'].rstrip('/')
        self._config_timeout = config.get('timeout', 1800)  # 30 minutes default
        self.return_embeddings = config.get('return_speaker_embeddings', False)
        self.default_diarize = config.get('diarize', True)

        # Configure chunking behavior based on environment variables.
        # ASR_ENABLE_CHUNKING=true enables app-level chunking for self-hosted ASR
        # services that may crash on long files due to GPU memory exhaustion.
        if ASR_ENABLE_CHUNKING:
            # Calculate recommended chunk size (80% of max for safety margin)
            recommended_chunk = int(ASR_MAX_DURATION_SECONDS * 0.8)
            self.SPECIFICATIONS = ConnectorSpecifications(
                max_file_size_bytes=None,  # No file size limit
                max_duration_seconds=ASR_MAX_DURATION_SECONDS,
                handles_chunking_internally=False,  # App handles chunking
                recommended_chunk_seconds=recommended_chunk,
            )
            logger.info(
                f"ASR chunking enabled: max_duration={ASR_MAX_DURATION_SECONDS}s, "
                f"recommended_chunk={recommended_chunk}s"
            )
        else:
            # Default behavior: ASR service handles everything internally
            self.SPECIFICATIONS = ConnectorSpecifications(
                max_file_size_bytes=None,
                max_duration_seconds=None,
                handles_chunking_internally=True,
            )

        # Add speaker embeddings capability if enabled
        if self.return_embeddings:
            self.CAPABILITIES = self.CAPABILITIES | {TranscriptionCapability.SPEAKER_EMBEDDINGS}

    @property
    def timeout(self):
        """Get ASR timeout, reading fresh from env/DB each time to respect runtime changes."""
        # Environment variables take priority
        env_timeout = os.environ.get('ASR_TIMEOUT') or os.environ.get('asr_timeout_seconds')
        if env_timeout:
            try:
                return int(env_timeout)
            except (ValueError, TypeError):
                pass

        # Try database setting (Admin UI)
        try:
            from src.models import SystemSetting
            db_timeout = SystemSetting.get_setting('asr_timeout_seconds', None)
            if db_timeout is not None:
                return int(db_timeout)
        except Exception:
            # Best-effort: fall through to the config value if the DB is unavailable.
            pass

        # Fall back to config value from initialization
        return self._config_timeout

    def _validate_config(self) -> None:
        """Validate required configuration.

        Raises:
            ConfigurationError: If ``base_url`` is missing.
        """
        if not self.config.get('base_url'):
            raise ConfigurationError("base_url is required for ASR endpoint connector")

    def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
        """
        Transcribe audio using ASR webservice.

        Args:
            request: Standardized transcription request

        Returns:
            TranscriptionResponse with segments and speaker information

        Raises:
            ProviderError: On HTTP errors or invalid (non-JSON) responses
            TranscriptionError: On timeouts or any other failure
        """
        try:
            url = f"{self.base_url}/asr"

            params = {
                'encode': True,
                'task': 'transcribe',
                'output': 'json'
            }

            if request.language:
                params['language'] = request.language
                logger.info(f"Using transcription language: {request.language}")

            # Determine if we should diarize
            should_diarize = request.diarize if request.diarize is not None else self.default_diarize

            # Send both parameter names for compatibility:
            # - 'diarize' is used by whisper-asr-webservice
            # - 'enable_diarization' is used by WhisperX
            params['diarize'] = should_diarize
            params['enable_diarization'] = should_diarize

            if should_diarize and self.return_embeddings:
                params['return_speaker_embeddings'] = True

            if request.min_speakers:
                params['min_speakers'] = request.min_speakers
            if request.max_speakers:
                params['max_speakers'] = request.max_speakers

            if request.prompt:
                params['initial_prompt'] = request.prompt
            if request.hotwords:
                params['hotwords'] = request.hotwords

            content_type = request.mime_type or 'application/octet-stream'
            files = {
                'audio_file': (request.filename, request.audio_file, content_type)
            }

            # Configure timeout: generous values for large file uploads.
            # Write timeout needs to be high too - large files take time to upload.
            timeout = httpx.Timeout(
                None,
                connect=60.0,
                read=float(self.timeout),
                write=float(self.timeout),
                pool=None
            )

            logger.info(f"Sending ASR request to {url} with params: {params} (timeout: {self.timeout}s)")

            with httpx.Client() as client:
                response = client.post(url, params=params, files=files, timeout=timeout)
                logger.info(f"ASR request completed with status: {response.status_code}")
                response.raise_for_status()

            # Parse the JSON response
            response_text = response.text
            try:
                data = response.json()
            except Exception as json_err:
                # An HTML body usually means a reverse proxy / gateway error page.
                if response_text.strip().startswith('<'):
                    logger.error(f"ASR returned HTML error page (status {response.status_code})")
                    raise ProviderError(
                        "ASR service returned HTML error page",
                        provider=self.PROVIDER_NAME,
                        status_code=response.status_code
                    )
                else:
                    raise ProviderError(
                        f"ASR service returned invalid response: {json_err}",
                        provider=self.PROVIDER_NAME,
                        status_code=response.status_code
                    )

            return self._parse_response(data)

        except httpx.HTTPStatusError as e:
            logger.error(f"ASR request failed with status {e.response.status_code}")
            raise ProviderError(
                f"ASR request failed with status {e.response.status_code}",
                provider=self.PROVIDER_NAME,
                status_code=e.response.status_code
            ) from e

        except httpx.TimeoutException as e:
            logger.error(f"ASR request timed out after {self.timeout}s")
            raise TranscriptionError(f"ASR request timed out after {self.timeout}s") from e

        except (TranscriptionError, ProviderError):
            # Preserve already-classified errors (e.g. the ProviderError raised
            # above for an invalid JSON body) instead of re-wrapping them and
            # losing provider/status detail. Matches the Azure connector's style.
            raise

        except Exception as e:
            error_msg = str(e)
            logger.error(f"ASR transcription failed: {error_msg}")
            raise TranscriptionError(f"ASR transcription failed: {error_msg}") from e

    def _parse_response(self, data: Dict[str, Any]) -> TranscriptionResponse:
        """
        Parse ASR webservice response into standardized format.

        The ASR response contains:
        - text: Full transcription text
        - language: Detected language
        - segments: Array of segments with speaker, text, start, end
        - speaker_embeddings: Optional speaker embeddings (WhisperX only)
        """
        segments = []
        speakers = set()
        full_text_parts = []
        last_speaker = None

        logger.info(f"ASR response keys: {list(data.keys())}")

        if 'segments' in data and isinstance(data['segments'], list):
            logger.info(f"Number of segments: {len(data['segments'])}")

            for seg in data['segments']:
                speaker = seg.get('speaker')

                # Handle missing speakers by carrying forward from previous segment
                if speaker is None:
                    if last_speaker is not None:
                        speaker = last_speaker
                    else:
                        speaker = 'UNKNOWN_SPEAKER'
                else:
                    last_speaker = speaker

                text = seg.get('text', '').strip()
                speakers.add(speaker)
                full_text_parts.append(f"[{speaker}]: {text}")

                segments.append(TranscriptionSegment(
                    text=text,
                    speaker=speaker,
                    start_time=seg.get('start'),
                    end_time=seg.get('end')
                ))

        # Get the full text; prefer the service's own 'text' field, otherwise
        # rebuild it from the speaker-labeled segments.
        if 'text' in data and isinstance(data['text'], str):
            full_text = data['text']
        elif full_text_parts:
            full_text = '\n'.join(full_text_parts)
        else:
            full_text = ''

        # Extract speaker embeddings if present
        speaker_embeddings = data.get('speaker_embeddings')
        if speaker_embeddings:
            logger.info(f"Received speaker embeddings for speakers: {list(speaker_embeddings.keys())}")

        logger.info(f"Parsed {len(segments)} segments with {len(speakers)} unique speakers: {sorted(speakers)}")

        return TranscriptionResponse(
            text=full_text,
            segments=segments,
            speakers=sorted(list(speakers)),
            speaker_embeddings=speaker_embeddings,
            language=data.get('language'),
            provider=self.PROVIDER_NAME,
            model="asr-endpoint",
            raw_response=data
        )

    def health_check(self) -> bool:
        """Check if ASR endpoint is reachable (any non-5xx answer counts as healthy)."""
        try:
            with httpx.Client(timeout=10.0) as client:
                # Try common health check endpoints
                for endpoint in ['/health', '/']:
                    try:
                        response = client.get(f"{self.base_url}{endpoint}")
                        if response.status_code < 500:
                            return True
                    except Exception:
                        continue
                return False
        except Exception:
            return False

    @classmethod
    def get_config_schema(cls) -> Dict[str, Any]:
        """Return JSON schema for configuration."""
        return {
            "type": "object",
            "required": ["base_url"],
            "properties": {
                "base_url": {
                    "type": "string",
                    "description": "ASR service base URL (e.g., http://whisper-asr:9000)"
                },
                "timeout": {
                    "type": "integer",
                    "default": 1800,
                    "description": "Request timeout in seconds"
                },
                "diarize": {
                    "type": "boolean",
                    "default": True,
                    "description": "Enable speaker diarization by default"
                },
                "return_speaker_embeddings": {
                    "type": "boolean",
                    "default": False,
                    "description": "Request speaker embeddings (WhisperX only)"
                }
            }
        }
370
src/services/transcription/connectors/azure_openai_transcribe.py
Normal file
370
src/services/transcription/connectors/azure_openai_transcribe.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
Azure OpenAI Transcribe connector.
|
||||
|
||||
Supports Azure OpenAI audio transcription models:
|
||||
- whisper-1: Basic transcription (no diarization)
|
||||
- gpt-4o-transcribe: High quality transcription
|
||||
- gpt-4o-mini-transcribe: Cost-effective transcription
|
||||
- gpt-4o-transcribe-diarize: Speaker diarization with labels A, B, C, D
|
||||
|
||||
Azure OpenAI uses a different API format than standard OpenAI:
|
||||
- Endpoint: https://{resource}.openai.azure.com/openai/deployments/{deployment}/audio/transcriptions
|
||||
- Requires api-version query parameter
|
||||
- Uses api-key header for authentication
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
from typing import Dict, Any, Set, Optional
|
||||
|
||||
from ..base import (
|
||||
BaseTranscriptionConnector,
|
||||
TranscriptionCapability,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionSegment,
|
||||
ConnectorSpecifications,
|
||||
)
|
||||
from ..exceptions import TranscriptionError, ConfigurationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAITranscribeConnector(BaseTranscriptionConnector):
    """Connector for Azure OpenAI audio transcription models."""

    # Base capabilities - diarization added dynamically based on model
    CAPABILITIES: Set[TranscriptionCapability] = {
        TranscriptionCapability.TIMESTAMPS,
        TranscriptionCapability.LANGUAGE_DETECTION,
    }
    PROVIDER_NAME = "azure_openai_transcribe"

    # Default specifications (will be overridden per-model in __init__)
    SPECIFICATIONS = ConnectorSpecifications(
        max_file_size_bytes=25 * 1024 * 1024,  # 25MB
        max_duration_seconds=1400,  # Default to most restrictive (diarize model)
        min_duration_for_chunking=30,
        handles_chunking_internally=False,
        requires_chunking_param=True,
        recommended_chunk_seconds=1200,
        unsupported_codecs=frozenset({'opus'}),
    )

    # Models and their capabilities
    MODELS = {
        'whisper-1': {
            'supports_diarization': False,
            'max_duration_seconds': 1500,
            'recommended_chunk_seconds': 1200,
            'description': 'OpenAI Whisper model on Azure'
        },
        'gpt-4o-transcribe': {
            'supports_diarization': False,
            'max_duration_seconds': 1500,
            'recommended_chunk_seconds': 1200,
            'description': 'High quality transcription'
        },
        'gpt-4o-mini-transcribe': {
            'supports_diarization': False,
            'max_duration_seconds': 1500,
            'recommended_chunk_seconds': 1200,
            'description': 'Cost-effective transcription'
        },
        'gpt-4o-mini-transcribe-2025-12-15': {
            'supports_diarization': False,
            'max_duration_seconds': 1500,
            'recommended_chunk_seconds': 1200,
            'description': 'Cost-effective transcription (dated version)'
        },
        'gpt-4o-transcribe-diarize': {
            'supports_diarization': True,
            'max_duration_seconds': 1400,
            'recommended_chunk_seconds': 1200,
            'description': 'Speaker diarization with labels A, B, C, D'
        }
    }

    # Default API version - can be overridden in config
    DEFAULT_API_VERSION = "2025-04-01-preview"

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the Azure OpenAI Transcribe connector.

        Args:
            config: Configuration dict with keys:
                - api_key: Azure OpenAI API key (required)
                - endpoint: Azure OpenAI endpoint URL (required)
                  e.g., https://your-resource.openai.azure.com
                - deployment_name: The deployment name for the model (required)
                - api_version: API version (default: 2025-04-01-preview)
                - model: Model name for validation (optional, defaults to deployment_name)
        """
        # Store model/deployment before calling super().__init__
        self.deployment_name = config.get('deployment_name', '')
        self.model = config.get('model', self.deployment_name)
        self.api_version = config.get('api_version', self.DEFAULT_API_VERSION)

        # Set model-specific specifications (shadows the class attribute)
        model_info = self.MODELS.get(self.model, {})
        if model_info:
            self.SPECIFICATIONS = ConnectorSpecifications(
                max_file_size_bytes=25 * 1024 * 1024,
                max_duration_seconds=model_info.get('max_duration_seconds', 1400),
                min_duration_for_chunking=30,
                handles_chunking_internally=False,
                requires_chunking_param=True,
                recommended_chunk_seconds=model_info.get('recommended_chunk_seconds', 1200),
                unsupported_codecs=frozenset({'opus'}),
            )

        super().__init__(config)

        # Parse endpoint URL
        self.endpoint = config['endpoint'].rstrip('/')

        # Set up HTTP client
        self.http_client = httpx.Client(
            timeout=httpx.Timeout(
                connect=60.0,
                read=1800.0,  # 30 minutes for long transcriptions
                write=1800.0,
                pool=None
            ),
            headers={
                "api-key": config['api_key'],
                "User-Agent": "Speakr/1.0 (https://github.com/murtaza-nasir/speakr)"
            }
        )

        # Dynamically update capabilities based on model
        if self._model_supports_diarization():
            self.CAPABILITIES = self.CAPABILITIES | {
                TranscriptionCapability.DIARIZATION,
                TranscriptionCapability.KNOWN_SPEAKERS
            }

    def _validate_config(self) -> None:
        """Validate required configuration.

        Raises:
            ConfigurationError: If api_key, endpoint, or deployment_name is missing.
        """
        if not self.config.get('api_key'):
            raise ConfigurationError("api_key is required for Azure OpenAI Transcribe connector")
        if not self.config.get('endpoint'):
            raise ConfigurationError("endpoint is required for Azure OpenAI Transcribe connector")
        if not self.config.get('deployment_name'):
            raise ConfigurationError("deployment_name is required for Azure OpenAI Transcribe connector")

    def _model_supports_diarization(self) -> bool:
        """Check if the current model supports diarization."""
        model_info = self.MODELS.get(self.model, {})
        return model_info.get('supports_diarization', False)

    def _build_url(self) -> str:
        """Build the Azure OpenAI transcription API URL."""
        return f"{self.endpoint}/openai/deployments/{self.deployment_name}/audio/transcriptions?api-version={self.api_version}"

    def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
        """
        Transcribe audio using Azure OpenAI API.

        Args:
            request: Standardized transcription request

        Returns:
            TranscriptionResponse, with segments if using diarization model

        Raises:
            TranscriptionError: On any non-200 response or request failure
        """
        try:
            url = self._build_url()

            # Build form data
            data = {}

            if request.language:
                data["language"] = request.language
                logger.info(f"Using transcription language: {request.language}")

            # Handle diarization model specifics
            is_diarize_model = 'diarize' in self.model.lower()

            if is_diarize_model:
                # Required: chunking_strategy for audio > 30 seconds
                data["chunking_strategy"] = "auto"

                if request.diarize:
                    data["response_format"] = "diarized_json"
                    logger.info("Using diarized_json response format for speaker diarization")

                # Known speaker support
                if request.known_speaker_names and request.known_speaker_references:
                    for i, name in enumerate(request.known_speaker_names):
                        if name in request.known_speaker_references:
                            data[f"known_speaker_names[{i}]"] = name
                            data[f"known_speaker_references[{i}]"] = request.known_speaker_references[name]
                    logger.info(f"Using known speaker references for {len(request.known_speaker_names)} speakers")
            else:
                # Non-diarization models - request verbose_json for timestamps
                data["response_format"] = "verbose_json"
                # Combine initial prompt and hotwords into a single prompt
                prompt_parts = []
                if request.prompt:
                    prompt_parts.append(request.prompt)
                if request.hotwords:
                    prompt_parts.append(request.hotwords)
                if prompt_parts:
                    data["prompt"] = ". ".join(prompt_parts)

            # Prepare file for upload
            content_type = request.mime_type or 'application/octet-stream'
            files = {
                "file": (request.filename, request.audio_file, content_type)
            }

            logger.info(f"Sending request to Azure OpenAI: {url}")
            logger.info(f"Model: {self.model}, Deployment: {self.deployment_name}")

            response = self.http_client.post(url, data=data, files=files)

            if response.status_code != 200:
                error_detail = response.text
                try:
                    error_json = response.json()
                    if 'error' in error_json:
                        error_detail = error_json['error'].get('message', error_detail)
                except Exception:
                    # Best-effort extraction only; keep the raw body on any
                    # parse failure. (Narrowed from a bare `except:` so that
                    # KeyboardInterrupt/SystemExit are not swallowed.)
                    pass
                logger.error(f"Azure OpenAI transcription failed: {response.status_code} - {error_detail}")
                raise TranscriptionError(f"Azure OpenAI transcription failed: {response.status_code} - {error_detail}")

            result = response.json()

            # Parse response based on format
            if is_diarize_model and request.diarize:
                return self._parse_diarized_response(result)
            else:
                return self._parse_response(result)

        except TranscriptionError:
            raise
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Azure OpenAI transcription failed: {error_msg}")
            raise TranscriptionError(f"Azure OpenAI transcription failed: {error_msg}") from e

    def _parse_response(self, response: Dict) -> TranscriptionResponse:
        """Parse a standard (non-diarized) response."""
        text = response.get('text', '')

        # Check for segments (verbose_json format)
        segments = []
        if 'segments' in response:
            for seg in response['segments']:
                segments.append(TranscriptionSegment(
                    text=seg.get('text', ''),
                    start_time=seg.get('start'),
                    end_time=seg.get('end')
                ))

        return TranscriptionResponse(
            text=text,
            segments=segments if segments else None,
            language=response.get('language'),
            provider=self.PROVIDER_NAME,
            model=self.model,
            raw_response=response
        )

    def _parse_diarized_response(self, response: Dict) -> TranscriptionResponse:
        """
        Parse diarized JSON response into standardized format.

        The diarized_json response contains segments with:
        - speaker: "A", "B", "C", "D" etc.
        - text: The transcribed text
        - start: Segment start time
        - end: Segment end time
        """
        segments = []
        speakers = set()
        full_text_parts = []

        raw_segments = response.get('segments', [])

        if not raw_segments:
            # Fallback to text-only response
            logger.warning("No segments found in diarized response, falling back to text")
            return self._parse_response(response)

        for seg in raw_segments:
            speaker = seg.get('speaker', 'Unknown')
            text = seg.get('text', '')
            start = seg.get('start')
            end = seg.get('end')

            # Skip empty segments
            if not text or not text.strip():
                continue

            speakers.add(speaker)
            full_text_parts.append(f"[{speaker}]: {text}")

            segments.append(TranscriptionSegment(
                text=text,
                speaker=speaker,
                start_time=start,
                end_time=end
            ))

        # Build full text with speaker labels
        full_text = '\n'.join(full_text_parts)

        logger.info(f"Parsed {len(segments)} segments with {len(speakers)} unique speakers: {sorted(speakers)}")

        return TranscriptionResponse(
            text=full_text,
            segments=segments,
            speakers=sorted(list(speakers)),
            language=response.get('language'),
            provider=self.PROVIDER_NAME,
            model=self.model,
            raw_response=response
        )

    def health_check(self) -> bool:
        """Check if the connector is properly configured (no network probe)."""
        return bool(
            self.config.get('api_key') and
            self.config.get('endpoint') and
            self.config.get('deployment_name')
        )

    @classmethod
    def get_config_schema(cls) -> Dict[str, Any]:
        """Return JSON schema for configuration."""
        return {
            "type": "object",
            "required": ["api_key", "endpoint", "deployment_name"],
            "properties": {
                "api_key": {
                    "type": "string",
                    "description": "Azure OpenAI API key"
                },
                "endpoint": {
                    "type": "string",
                    "description": "Azure OpenAI endpoint URL (e.g., https://your-resource.openai.azure.com)"
                },
                "deployment_name": {
                    "type": "string",
                    "description": "The deployment name for your transcription model"
                },
                "api_version": {
                    "type": "string",
                    "default": cls.DEFAULT_API_VERSION,
                    "description": "Azure OpenAI API version"
                },
                "model": {
                    "type": "string",
                    "enum": list(cls.MODELS.keys()),
                    "description": "Model type (for capability detection, defaults to deployment_name)"
                }
            }
        }
|
||||
329
src/services/transcription/connectors/openai_transcribe.py
Normal file
329
src/services/transcription/connectors/openai_transcribe.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""
|
||||
OpenAI GPT-4o Transcribe connector.
|
||||
|
||||
Supports the newer GPT-4o based transcription models:
|
||||
- gpt-4o-transcribe: High quality transcription
|
||||
- gpt-4o-mini-transcribe: Cost-effective transcription
|
||||
- gpt-4o-transcribe-diarize: Speaker diarization with labels A, B, C, D
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
from typing import Dict, Any, Set, Optional
|
||||
|
||||
from ..base import (
|
||||
BaseTranscriptionConnector,
|
||||
TranscriptionCapability,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionSegment,
|
||||
ConnectorSpecifications,
|
||||
)
|
||||
from ..exceptions import TranscriptionError, ConfigurationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAITranscribeConnector(BaseTranscriptionConnector):
|
||||
"""Connector for GPT-4o Transcribe models with optional diarization support."""
|
||||
|
||||
# Base capabilities - diarization added dynamically based on model
|
||||
CAPABILITIES: Set[TranscriptionCapability] = {
|
||||
TranscriptionCapability.TIMESTAMPS,
|
||||
TranscriptionCapability.LANGUAGE_DETECTION,
|
||||
}
|
||||
PROVIDER_NAME = "openai_transcribe"
|
||||
|
||||
# GPT-4o Transcribe models have specific constraints
|
||||
# - 25MB file size limit (all models)
|
||||
# - Duration limits vary by model:
|
||||
# - gpt-4o-transcribe / gpt-4o-mini-transcribe: 1500 seconds (25 min)
|
||||
# - gpt-4o-transcribe-diarize: 1400 seconds (~23 min)
|
||||
# - chunking_strategy="auto" handles files internally up to the duration limit
|
||||
# Supported formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, flac, ogg, oga
|
||||
# NOT supported: opus (used by WhatsApp voice notes, Discord)
|
||||
|
||||
# Default specifications (will be overridden per-model in __init__)
|
||||
SPECIFICATIONS = ConnectorSpecifications(
|
||||
max_file_size_bytes=25 * 1024 * 1024, # 25MB
|
||||
max_duration_seconds=1400, # Default to most restrictive (diarize model)
|
||||
min_duration_for_chunking=30, # >30s needs chunking_strategy param
|
||||
handles_chunking_internally=False, # App must chunk files > max_duration_seconds
|
||||
requires_chunking_param=True, # Must send chunking_strategy for >30s
|
||||
recommended_chunk_seconds=1200, # 20 minutes - safe margin
|
||||
unsupported_codecs=frozenset({'opus'}), # OpenAI API doesn't support opus
|
||||
)
|
||||
|
||||
# Models and their capabilities with duration limits
|
||||
MODELS = {
|
||||
'gpt-4o-transcribe': {
|
||||
'supports_diarization': False,
|
||||
'max_duration_seconds': 1500, # 25 minutes
|
||||
'recommended_chunk_seconds': 1200, # 20 minutes
|
||||
'description': 'High quality transcription'
|
||||
},
|
||||
'gpt-4o-mini-transcribe': {
|
||||
'supports_diarization': False,
|
||||
'max_duration_seconds': 1500, # 25 minutes
|
||||
'recommended_chunk_seconds': 1200, # 20 minutes
|
||||
'description': 'Cost-effective transcription'
|
||||
},
|
||||
'gpt-4o-mini-transcribe-2025-12-15': {
|
||||
'supports_diarization': False,
|
||||
'max_duration_seconds': 1500, # 25 minutes
|
||||
'recommended_chunk_seconds': 1200, # 20 minutes
|
||||
'description': 'Cost-effective transcription (dated version)'
|
||||
},
|
||||
'gpt-4o-transcribe-diarize': {
|
||||
'supports_diarization': True,
|
||||
'max_duration_seconds': 1400, # ~23 minutes (more restrictive)
|
||||
'recommended_chunk_seconds': 1200, # 20 minutes
|
||||
'description': 'Speaker diarization with labels A, B, C, D'
|
||||
}
|
||||
}
|
||||
|
||||
    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the GPT-4o Transcribe connector.

        Args:
            config: Configuration dict with keys:
                - api_key: OpenAI API key (required)
                - base_url: API base URL (default: https://api.openai.com/v1)
                - model: Model name (required, one of MODELS)
                - http_client: Optional httpx.Client instance
        """
        # Store model before calling super().__init__ since _validate_config needs it
        self.model = config.get('model', 'gpt-4o-transcribe')

        # Set model-specific specifications (override class defaults)
        # Use SPECIFICATIONS (uppercase) to shadow the class attribute
        model_info = self.MODELS.get(self.model, {})
        self.SPECIFICATIONS = ConnectorSpecifications(
            max_file_size_bytes=25 * 1024 * 1024,  # 25MB (same for all)
            max_duration_seconds=model_info.get('max_duration_seconds', 1400),
            min_duration_for_chunking=30,
            handles_chunking_internally=False,
            requires_chunking_param=True,
            recommended_chunk_seconds=model_info.get('recommended_chunk_seconds', 1200),
            unsupported_codecs=frozenset({'opus'}),
        )

        # super().__init__ is expected to run _validate_config; self.model
        # must already be set at this point.
        super().__init__(config)

        # Set up HTTP client with custom headers identifying the app.
        # A caller-supplied http_client (e.g. for tests) takes precedence.
        http_client = config.get('http_client')
        if not http_client:
            app_headers = {
                "HTTP-Referer": "https://github.com/murtaza-nasir/speakr",
                "X-Title": "Speakr - AI Audio Transcription",
                "User-Agent": "Speakr/1.0 (https://github.com/murtaza-nasir/speakr)"
            }
            http_client = httpx.Client(verify=True, headers=app_headers)

        self.client = OpenAI(
            api_key=config['api_key'],
            base_url=config.get('base_url', 'https://api.openai.com/v1'),
            http_client=http_client
        )

        # Dynamically update capabilities based on model: only the diarize
        # model advertises DIARIZATION / KNOWN_SPEAKERS.
        if self._model_supports_diarization():
            self.CAPABILITIES = self.CAPABILITIES | {
                TranscriptionCapability.DIARIZATION,
                TranscriptionCapability.KNOWN_SPEAKERS
            }
|
||||
|
||||
def _validate_config(self) -> None:
|
||||
"""Validate required configuration."""
|
||||
if not self.config.get('api_key'):
|
||||
raise ConfigurationError("api_key is required for OpenAI Transcribe connector")
|
||||
|
||||
model = self.config.get('model', 'gpt-4o-transcribe')
|
||||
if model not in self.MODELS:
|
||||
raise ConfigurationError(
|
||||
f"Unknown model: {model}. Valid models: {list(self.MODELS.keys())}"
|
||||
)
|
||||
|
||||
def _model_supports_diarization(self) -> bool:
|
||||
"""Check if the current model supports diarization."""
|
||||
model_info = self.MODELS.get(self.model, {})
|
||||
return model_info.get('supports_diarization', False)
|
||||
|
||||
    def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
        """
        Transcribe audio using GPT-4o Transcribe API.

        Args:
            request: Standardized transcription request

        Returns:
            TranscriptionResponse, with segments if using diarization model

        Raises:
            TranscriptionError: wrapping any exception from the API call.
        """
        try:
            params = {
                "model": self.model,
                "file": request.audio_file,
            }

            if request.language:
                params["language"] = request.language
                logger.info(f"Using transcription language: {request.language}")

            # Handle diarization model specifics
            if self.model == 'gpt-4o-transcribe-diarize':
                # Required: chunking_strategy for audio > 30 seconds
                params["chunking_strategy"] = "auto"

                if request.diarize:
                    params["response_format"] = "diarized_json"
                    logger.info("Using diarized_json response format for speaker diarization")

                # Known speaker support for maintaining speaker identity across chunks
                # known_speaker_names is a list of speaker labels (e.g., ["A", "B"])
                # known_speaker_references is a dict mapping label to data URL
                if request.known_speaker_names and request.known_speaker_references:
                    # OpenAI expects lists for both parameters; keep only names
                    # that actually have a reference, preserving caller order.
                    speaker_names = []
                    speaker_refs = []

                    for name in request.known_speaker_names:
                        if name in request.known_speaker_references:
                            speaker_names.append(name)
                            speaker_refs.append(request.known_speaker_references[name])

                    if speaker_names:
                        # Use extra_body to pass the known speaker parameters
                        params["extra_body"] = {
                            "known_speaker_names": speaker_names,
                            "known_speaker_references": speaker_refs
                        }
                        logger.info(f"Using known speaker references for {len(speaker_names)} speakers: {speaker_names}")
            else:
                # Non-diarization models - combine initial prompt and hotwords
                # into the single free-text prompt the API accepts.
                prompt_parts = []
                if request.prompt:
                    prompt_parts.append(request.prompt)
                if request.hotwords:
                    prompt_parts.append(request.hotwords)
                if prompt_parts:
                    params["prompt"] = ". ".join(prompt_parts)

            logger.info(f"Sending request to GPT-4o Transcribe API with model: {self.model}")
            response = self.client.audio.transcriptions.create(**params)

            # Parse response based on format: diarized JSON only when the
            # diarize model was used AND diarization was requested.
            if self.model == 'gpt-4o-transcribe-diarize' and request.diarize:
                return self._parse_diarized_response(response)
            else:
                return self._parse_text_response(response)

        except Exception as e:
            error_msg = str(e)
            logger.error(f"GPT-4o transcription failed: {error_msg}")
            raise TranscriptionError(f"GPT-4o transcription failed: {error_msg}") from e
|
||||
|
||||
def _parse_text_response(self, response) -> TranscriptionResponse:
|
||||
"""Parse a plain text response."""
|
||||
text = response.text if hasattr(response, 'text') else str(response)
|
||||
return TranscriptionResponse(
|
||||
text=text,
|
||||
provider=self.PROVIDER_NAME,
|
||||
model=self.model
|
||||
)
|
||||
|
||||
    def _parse_diarized_response(self, response) -> TranscriptionResponse:
        """
        Parse diarized JSON response into standardized format.

        The diarized_json response contains segments with:
        - speaker: "A", "B", "C", "D" etc.
        - text: The transcribed text
        - start: Segment start time
        - end: Segment end time

        Falls back to plain-text parsing when no segments are present.
        """
        segments = []
        speakers = set()
        full_text_parts = []

        # Handle response object - could be dict or object with attributes
        if hasattr(response, 'segments'):
            raw_segments = response.segments
        elif isinstance(response, dict) and 'segments' in response:
            raw_segments = response['segments']
        else:
            # Fallback to text-only response
            logger.warning("No segments found in diarized response, falling back to text")
            return self._parse_text_response(response)

        for seg in raw_segments:
            # Handle both dict and object segments
            if isinstance(seg, dict):
                speaker = seg.get('speaker', 'Unknown')
                text = seg.get('text', '')
                start = seg.get('start')
                end = seg.get('end')
            else:
                speaker = getattr(seg, 'speaker', 'Unknown')
                text = getattr(seg, 'text', '')
                start = getattr(seg, 'start', None)
                end = getattr(seg, 'end', None)

            # Skip empty segments
            if not text or not text.strip():
                continue

            speakers.add(speaker)
            full_text_parts.append(f"[{speaker}]: {text}")

            segments.append(TranscriptionSegment(
                text=text,
                speaker=speaker,
                start_time=start,
                end_time=end
            ))

        # Always use our formatted text with speaker labels for diarized responses
        # OpenAI's response.text is plain text WITHOUT speaker labels
        full_text = '\n'.join(full_text_parts)

        logger.info(f"Parsed {len(segments)} segments with {len(speakers)} unique speakers: {sorted(speakers)}")

        return TranscriptionResponse(
            text=full_text,
            segments=segments,
            speakers=sorted(list(speakers)),
            provider=self.PROVIDER_NAME,
            model=self.model,
            # raw_response is only retained for dict payloads (JSON-serializable)
            raw_response=response if isinstance(response, dict) else None
        )
|
||||
|
||||
def health_check(self) -> bool:
|
||||
"""Check if the connector is properly configured."""
|
||||
return bool(self.config.get('api_key'))
|
||||
|
||||
@classmethod
|
||||
def get_config_schema(cls) -> Dict[str, Any]:
|
||||
"""Return JSON schema for configuration."""
|
||||
return {
|
||||
"type": "object",
|
||||
"required": ["api_key"],
|
||||
"properties": {
|
||||
"api_key": {
|
||||
"type": "string",
|
||||
"description": "OpenAI API key"
|
||||
},
|
||||
"base_url": {
|
||||
"type": "string",
|
||||
"default": "https://api.openai.com/v1",
|
||||
"description": "API base URL"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": list(cls.MODELS.keys()),
|
||||
"default": "gpt-4o-transcribe",
|
||||
"description": "GPT-4o transcription model to use"
|
||||
}
|
||||
}
|
||||
}
|
||||
153
src/services/transcription/connectors/openai_whisper.py
Normal file
153
src/services/transcription/connectors/openai_whisper.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
OpenAI Whisper API connector (whisper-1 model).
|
||||
|
||||
This is the legacy Whisper API connector that supports the whisper-1 model.
|
||||
It returns plain text transcriptions without speaker diarization.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
from typing import Dict, Any, Set
|
||||
|
||||
from ..base import (
|
||||
BaseTranscriptionConnector,
|
||||
TranscriptionCapability,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
ConnectorSpecifications,
|
||||
)
|
||||
from ..exceptions import TranscriptionError, ConfigurationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIWhisperConnector(BaseTranscriptionConnector):
    """Connector for OpenAI Whisper API (whisper-1 model)."""

    CAPABILITIES: Set[TranscriptionCapability] = {
        TranscriptionCapability.CHUNKING,
        TranscriptionCapability.TIMESTAMPS,
        TranscriptionCapability.LANGUAGE_DETECTION,
    }
    PROVIDER_NAME = "openai_whisper"

    # OpenAI Whisper caps uploads at 25MB and expects the application to
    # chunk longer recordings itself.
    # Supported formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, flac, ogg, oga
    # NOT supported: opus (used by WhatsApp voice notes, Discord)
    SPECIFICATIONS = ConnectorSpecifications(
        max_file_size_bytes=25 * 1024 * 1024,  # 25MB
        handles_chunking_internally=False,
        recommended_chunk_seconds=600,  # 10 minutes
        unsupported_codecs=frozenset({'opus'}),  # OpenAI API doesn't support opus
    )

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the Whisper connector.

        Args:
            config: Configuration dict with keys:
                - api_key: OpenAI API key (required)
                - base_url: API base URL (optional)
                - model: Model name (default: whisper-1)
                - http_client: Optional httpx.Client instance
        """
        super().__init__(config)

        # Reuse a caller-supplied HTTP client; otherwise build one that
        # identifies the application via custom headers.
        client = config.get('http_client')
        if not client:
            client = httpx.Client(verify=True, headers={
                "HTTP-Referer": "https://github.com/murtaza-nasir/speakr",
                "X-Title": "Speakr - AI Audio Transcription",
                "User-Agent": "Speakr/1.0 (https://github.com/murtaza-nasir/speakr)"
            })

        self.client = OpenAI(
            api_key=config['api_key'],
            base_url=config.get('base_url') or None,
            http_client=client,
        )
        self.model = config.get('model', 'whisper-1')

    def _validate_config(self) -> None:
        """Validate required configuration (api_key must be present)."""
        if self.config.get('api_key'):
            return
        raise ConfigurationError("api_key is required for OpenAI Whisper connector")

    def transcribe(self, request: TranscriptionRequest) -> TranscriptionResponse:
        """
        Transcribe audio using OpenAI Whisper API.

        Args:
            request: Standardized transcription request

        Returns:
            TranscriptionResponse with plain text (no diarization)

        Raises:
            TranscriptionError: wrapping any exception from the API call.
        """
        try:
            api_params = {"model": self.model, "file": request.audio_file}

            if request.language:
                api_params["language"] = request.language
                logger.info(f"Using transcription language: {request.language}")

            # Whisper exposes a single free-text prompt that serves both as
            # a steering prompt and as a vocabulary hint, so merge the two.
            hints = [part for part in (request.prompt, request.hotwords) if part]
            if hints:
                api_params["prompt"] = ". ".join(hints)

            if request.temperature is not None:
                api_params["temperature"] = request.temperature

            logger.info(f"Sending request to Whisper API with model: {self.model}")
            result = self.client.audio.transcriptions.create(**api_params)

            return TranscriptionResponse(
                text=result.text,
                provider=self.PROVIDER_NAME,
                model=self.model,
            )

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Whisper transcription failed: {error_msg}")
            raise TranscriptionError(f"Whisper transcription failed: {error_msg}") from e

    def health_check(self) -> bool:
        """Report whether an API key is configured (no network call)."""
        return bool(self.config.get('api_key'))

    @classmethod
    def get_config_schema(cls) -> Dict[str, Any]:
        """Describe the accepted configuration keys as a JSON schema dict."""
        schema_properties = {
            "api_key": {
                "type": "string",
                "description": "OpenAI API key"
            },
            "base_url": {
                "type": "string",
                "description": "API base URL (optional, for OpenAI-compatible endpoints)"
            },
            "model": {
                "type": "string",
                "default": "whisper-1",
                "description": "Whisper model to use"
            }
        }
        return {
            "type": "object",
            "required": ["api_key"],
            "properties": schema_properties,
        }
|
||||
32
src/services/transcription/exceptions.py
Normal file
32
src/services/transcription/exceptions.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Custom exceptions for transcription services.
|
||||
"""
|
||||
|
||||
|
||||
class TranscriptionError(Exception):
    """Base exception for transcription errors.

    All transcription-service exceptions derive from this class so callers
    can catch the whole family with a single except clause.
    """
|
||||
|
||||
|
||||
class ConfigurationError(TranscriptionError):
    """Configuration-related errors (missing or invalid config)."""
|
||||
|
||||
|
||||
class ProviderError(TranscriptionError):
    """Provider/API errors.

    Carries the originating provider name and HTTP status code (when known)
    alongside the error message.
    """

    def __init__(self, message: str, provider: str = None, status_code: int = None):
        super().__init__(message)
        # HTTP status returned by the provider, if one was received.
        self.status_code = status_code
        # Name of the backend that raised the error; may be None.
        self.provider = provider
|
||||
|
||||
|
||||
class AudioFormatError(TranscriptionError):
    """Unsupported audio format errors."""
|
||||
|
||||
|
||||
class ChunkingError(TranscriptionError):
    """Errors during file chunking."""
|
||||
353
src/services/transcription/registry.py
Normal file
353
src/services/transcription/registry.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Connector registry for managing transcription connectors.
|
||||
|
||||
Provides factory pattern for creating and managing transcription connectors,
|
||||
with auto-detection from environment variables for backwards compatibility.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Type, List
|
||||
|
||||
from .base import BaseTranscriptionConnector, TranscriptionCapability, TranscriptionRequest, TranscriptionResponse
|
||||
from .exceptions import ConfigurationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectorRegistry:
    """
    Registry for managing transcription connectors.

    Singleton pattern - use get_registry() to get the shared instance.
    """

    # NOTE: these are class-level attributes, so the state is shared by
    # every instance — consistent with the __new__ singleton below.
    _instance = None
    _connectors: Dict[str, Type[BaseTranscriptionConnector]] = {}
    _active_connector: Optional[BaseTranscriptionConnector] = None
    _connector_name: str = ""
    _initialized: bool = False

    def __new__(cls):
        # Classic singleton: every ConnectorRegistry() call yields the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ still runs on every ConnectorRegistry() call even though
        # __new__ returns the shared instance, so guard against re-registration.
        if self._initialized:
            return
        self._register_builtin_connectors()
        self._initialized = True

    def _register_builtin_connectors(self):
        """Register all built-in connectors."""
        # Imported inside the method rather than at module top; presumably to
        # avoid import cycles at package load time — TODO confirm.
        from .connectors.openai_whisper import OpenAIWhisperConnector
        from .connectors.openai_transcribe import OpenAITranscribeConnector
        from .connectors.asr_endpoint import ASREndpointConnector
        from .connectors.azure_openai_transcribe import AzureOpenAITranscribeConnector

        self.register('openai_whisper', OpenAIWhisperConnector)
        self.register('openai_transcribe', OpenAITranscribeConnector)
        self.register('asr_endpoint', ASREndpointConnector)
        self.register('azure_openai_transcribe', AzureOpenAITranscribeConnector)

    def register(self, name: str, connector_class: Type[BaseTranscriptionConnector]):
        """
        Register a connector class.

        Args:
            name: Unique name for the connector
            connector_class: The connector class to register
        """
        # Re-registering a name silently replaces the previous class.
        self._connectors[name] = connector_class
        logger.debug(f"Registered transcription connector: {name}")

    def get_connector_class(self, name: str) -> Type[BaseTranscriptionConnector]:
        """
        Get a connector class by name.

        Args:
            name: The connector name

        Returns:
            The connector class

        Raises:
            ConfigurationError: If connector not found
        """
        if name not in self._connectors:
            raise ConfigurationError(
                f"Unknown connector: {name}. Available: {list(self._connectors.keys())}"
            )
        return self._connectors[name]

    def create_connector(self, name: str, config: Dict[str, Any]) -> BaseTranscriptionConnector:
        """
        Create a connector instance.

        Args:
            name: The connector name
            config: Configuration dict for the connector

        Returns:
            Configured connector instance
        """
        connector_class = self.get_connector_class(name)
        return connector_class(config)

    def list_connectors(self) -> List[Dict[str, Any]]:
        """
        List all registered connectors with their capabilities.

        Returns:
            List of connector info dicts (name, provider_name, capabilities,
            config_schema).
        """
        result = []
        for name, cls in self._connectors.items():
            result.append({
                'name': name,
                'provider_name': cls.PROVIDER_NAME,
                'capabilities': [c.name for c in cls.CAPABILITIES],
                'config_schema': cls.get_config_schema()
            })
        return result

    def initialize_from_env(self) -> BaseTranscriptionConnector:
        """
        Initialize the active connector from environment variables.

        Auto-detection priority:
        1. TRANSCRIPTION_CONNECTOR - explicit connector name
        2. ASR_BASE_URL is set - use ASR endpoint (smarter detection)
           - USE_ASR_ENDPOINT=true also works (backwards compat, with deprecation warning)
        3. TRANSCRIPTION_MODEL contains 'gpt-4o' - use OpenAI Transcribe
        4. TRANSCRIPTION_MODEL is set - use OpenAI Whisper with that model
        5. Default to OpenAI Whisper (whisper-1)

        Returns:
            The initialized connector

        Raises:
            ConfigurationError: if the chosen connector fails to initialize.
        """
        connector_name = os.environ.get('TRANSCRIPTION_CONNECTOR', '').lower().strip()

        if not connector_name:
            # Auto-detect based on existing config for backwards compatibility
            asr_base_url = os.environ.get('ASR_BASE_URL', '').strip()
            use_asr_flag = os.environ.get('USE_ASR_ENDPOINT', 'false').lower() == 'true'
            transcription_model = os.environ.get('TRANSCRIPTION_MODEL', '').lower()
            whisper_model = os.environ.get('WHISPER_MODEL', '').lower()

            # Deprecation warning for legacy USE_ASR_ENDPOINT flag
            if use_asr_flag:
                logger.warning(
                    "USE_ASR_ENDPOINT=true is deprecated. "
                    "Set ASR_BASE_URL instead for auto-detection, or use TRANSCRIPTION_CONNECTOR=asr_endpoint"
                )

            # Priority 2: ASR endpoint - check ASR_BASE_URL or legacy flag
            if asr_base_url or use_asr_flag:
                connector_name = 'asr_endpoint'
                if asr_base_url:
                    logger.info("Auto-detected ASR endpoint from ASR_BASE_URL")
            # Priority 2.5: Azure OpenAI - check for Azure endpoint URL
            elif self._is_azure_endpoint():
                connector_name = 'azure_openai_transcribe'
                logger.info("Auto-detected Azure OpenAI from TRANSCRIPTION_BASE_URL")
            # Priority 3: Model-based detection
            elif transcription_model and 'gpt-4o' in transcription_model:
                connector_name = 'openai_transcribe'
                logger.info(f"Auto-detected OpenAI Transcribe from TRANSCRIPTION_MODEL={transcription_model}")
            # Priority 4 & 5: OpenAI Whisper (with custom or default model)
            else:
                connector_name = 'openai_whisper'
                model = transcription_model or whisper_model or 'whisper-1'
                logger.info(f"Using OpenAI Whisper connector with model: {model}")

        config = self._build_config_from_env(connector_name)

        try:
            self._active_connector = self.create_connector(connector_name, config)
            self._connector_name = connector_name

            logger.info(f"Initialized transcription connector: {connector_name}")
            logger.info(f"Capabilities: {[c.name for c in self._active_connector.get_capabilities()]}")

            return self._active_connector

        except Exception as e:
            logger.error(f"Failed to initialize connector '{connector_name}': {e}")
            raise ConfigurationError(f"Failed to initialize connector '{connector_name}': {e}") from e

    def _get_asr_timeout(self) -> int:
        """
        Get ASR timeout with fallback chain: ENV -> Admin UI -> default.

        Priority:
        1. ASR_TIMEOUT environment variable
        2. asr_timeout_seconds environment variable (legacy)
        3. SystemSetting from Admin UI (database)
        4. Default: 1800 seconds (30 minutes)
        """
        # Check environment variables first
        env_timeout = os.environ.get('ASR_TIMEOUT') or os.environ.get('asr_timeout_seconds')
        if env_timeout:
            # NOTE(review): a non-numeric value raises ValueError here — confirm
            # that is the desired failure mode for a bad env var.
            return int(env_timeout)

        # Fall back to Admin UI setting (SystemSetting in database)
        try:
            from src.models import SystemSetting
            db_timeout = SystemSetting.get_setting('asr_timeout_seconds', None)
            if db_timeout is not None:
                return int(db_timeout)
        except Exception as e:
            # May fail if no app context or during initialization
            logger.debug(f"Could not read ASR timeout from database: {e}")

        # Default: 30 minutes
        return 1800

    def _is_azure_endpoint(self) -> bool:
        """Check if the TRANSCRIPTION_BASE_URL points to an Azure OpenAI endpoint."""
        base_url = os.environ.get('TRANSCRIPTION_BASE_URL', '').lower()
        return '.openai.azure.com' in base_url or '.cognitiveservices.azure.com' in base_url

    def _build_config_from_env(self, connector_name: str) -> Dict[str, Any]:
        """
        Build connector config from environment variables.

        Args:
            connector_name: The connector to build config for

        Returns:
            Configuration dict
        """
        if connector_name == 'asr_endpoint':
            base_url = os.environ.get('ASR_BASE_URL', '')
            if base_url:
                # Strip trailing inline comments (everything after '#')
                base_url = base_url.split('#')[0].strip()

            return {
                'base_url': base_url,
                'timeout': self._get_asr_timeout(),
                'diarize': os.environ.get('ASR_DIARIZE', 'true').lower() == 'true',
                'return_speaker_embeddings': os.environ.get('ASR_RETURN_SPEAKER_EMBEDDINGS', 'false').lower() == 'true'
            }

        elif connector_name == 'openai_transcribe':
            base_url = os.environ.get('TRANSCRIPTION_BASE_URL', 'https://api.openai.com/v1')
            if base_url:
                base_url = base_url.split('#')[0].strip()

            return {
                'api_key': os.environ.get('TRANSCRIPTION_API_KEY', ''),
                'base_url': base_url,
                'model': os.environ.get('TRANSCRIPTION_MODEL', 'gpt-4o-transcribe')
            }

        elif connector_name == 'azure_openai_transcribe':
            # Azure OpenAI requires endpoint and deployment_name
            # TRANSCRIPTION_BASE_URL should be the Azure endpoint (e.g., https://your-resource.openai.azure.com)
            endpoint = os.environ.get('TRANSCRIPTION_BASE_URL', '')
            if endpoint:
                endpoint = endpoint.split('#')[0].strip()
                # Remove any trailing /openai or /v1 paths - we build the full URL ourselves
                endpoint = endpoint.rstrip('/')
                for suffix in ['/openai/v1', '/openai', '/v1']:
                    if endpoint.lower().endswith(suffix):
                        endpoint = endpoint[:-len(suffix)]

            return {
                'api_key': os.environ.get('TRANSCRIPTION_API_KEY', ''),
                'endpoint': endpoint,
                'deployment_name': os.environ.get('AZURE_DEPLOYMENT_NAME', os.environ.get('TRANSCRIPTION_MODEL', 'gpt-4o-transcribe')),
                'api_version': os.environ.get('AZURE_API_VERSION', '2025-04-01-preview'),
                'model': os.environ.get('TRANSCRIPTION_MODEL', '')  # For capability detection
            }

        else:  # openai_whisper (default)
            base_url = os.environ.get('TRANSCRIPTION_BASE_URL', '')
            if base_url:
                base_url = base_url.split('#')[0].strip()

            # Support both TRANSCRIPTION_MODEL and legacy WHISPER_MODEL
            # TRANSCRIPTION_MODEL takes priority for custom Whisper variants
            model = os.environ.get('TRANSCRIPTION_MODEL', '') or os.environ.get('WHISPER_MODEL', 'whisper-1')

            return {
                'api_key': os.environ.get('TRANSCRIPTION_API_KEY', ''),
                'base_url': base_url or None,
                'model': model
            }

    def get_active_connector(self) -> BaseTranscriptionConnector:
        """
        Get the currently active connector.

        Initializes from environment if not already initialized.

        Returns:
            The active connector
        """
        if not self._active_connector:
            self.initialize_from_env()
        return self._active_connector

    def get_active_connector_name(self) -> str:
        """Get the name of the currently active connector."""
        if not self._active_connector:
            self.initialize_from_env()
        return self._connector_name

    def reinitialize(self) -> BaseTranscriptionConnector:
        """
        Force re-initialization of the connector.

        Useful when environment variables have changed.

        Returns:
            The newly initialized connector
        """
        self._active_connector = None
        self._connector_name = ""
        return self.initialize_from_env()
|
||||
|
||||
|
||||
# Global registry instance (lazily created by get_registry)
_registry: Optional[ConnectorRegistry] = None


def get_registry() -> ConnectorRegistry:
    """Get the global connector registry, creating it on first use."""
    global _registry
    if _registry is None:
        _registry = ConnectorRegistry()
    return _registry


# Convenience aliases
# NOTE: evaluated at import time, so importing this module constructs the
# registry (and registers the built-in connectors) as a side effect.
connector_registry = get_registry()
|
||||
|
||||
|
||||
def transcribe(request: TranscriptionRequest) -> TranscriptionResponse:
    """
    Transcribe audio using the active connector.

    Convenience wrapper that routes the call through the global registry.

    Args:
        request: The transcription request

    Returns:
        Transcription response
    """
    return get_registry().get_active_connector().transcribe(request)
|
||||
|
||||
|
||||
def get_connector() -> BaseTranscriptionConnector:
    """Return the active transcription connector from the global registry."""
    registry = get_registry()
    return registry.get_active_connector()
|
||||
|
||||
|
||||
def supports_diarization() -> bool:
    """Check if the active connector supports diarization."""
    active = get_registry().get_active_connector()
    return active.supports_diarization
|
||||
312
src/services/transcription_tracking.py
Normal file
312
src/services/transcription_tracking.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Transcription usage tracking service for monitoring audio transcription consumption and budget enforcement.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Tuple, Optional, Dict, List
|
||||
|
||||
from sqlalchemy import func, extract
|
||||
|
||||
from src.database import db
|
||||
from src.models.transcription_usage import TranscriptionUsage
|
||||
from src.models.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Pricing configuration per connector/model (dollars per minute).
# Each connector maps model names to a per-minute price; the 'default' key
# is the fallback used by get_transcription_cost_per_minute for models not
# listed explicitly.
TRANSCRIPTION_PRICING = {
    'openai_whisper': {
        'whisper-1': 0.006,  # $0.006/min
        'default': 0.006,
    },
    'openai_transcribe': {
        'gpt-4o-transcribe': 0.006,  # $0.006/min
        'gpt-4o-mini-transcribe': 0.003,  # $0.003/min
        'gpt-4o-mini-transcribe-2025-12-15': 0.003,
        'gpt-4o-transcribe-diarize': 0.006,
        'default': 0.006,
    },
    'asr_endpoint': {
        'default': 0.0,  # Self-hosted = free
    },
}
|
||||
|
||||
|
||||
def get_transcription_cost_per_minute(connector_type: str, model_name: Optional[str] = None) -> float:
    """
    Get the cost per minute for a given connector and model.

    Args:
        connector_type: The connector provider name (e.g. 'openai_whisper')
        model_name: The specific model (optional); when omitted or not listed
            for the connector, the connector's 'default' price is used.

    Returns:
        Cost per minute in dollars (0.0 for unknown connectors).
    """
    # Fix: annotate model_name as Optional[str] — the previous `str = None`
    # annotation contradicted the None default (Optional is already imported
    # by this module).
    connector_pricing = TRANSCRIPTION_PRICING.get(connector_type, {})

    if model_name and model_name in connector_pricing:
        return connector_pricing[model_name]

    # Fall back to 'default' pricing for the connector; unknown connectors
    # yield an empty dict and therefore 0.0.
    return connector_pricing.get('default', 0.0)
|
||||
|
||||
|
||||
class TranscriptionTracker:
    """Service for recording and checking transcription usage.

    Usage is aggregated into one ``TranscriptionUsage`` row per
    (user, date, connector). Budgets are expressed in seconds on
    ``User.monthly_transcription_budget``; a falsy budget (None/0)
    means unlimited.
    """

    # Known connector provider identifiers (keys of TRANSCRIPTION_PRICING).
    CONNECTOR_TYPES = [
        'openai_whisper',
        'openai_transcribe',
        'asr_endpoint',
    ]

    def record_usage(
        self,
        user_id: int,
        connector_type: str,
        audio_duration_seconds: int,
        model_name: Optional[str] = None,
        estimated_cost: Optional[float] = None
    ) -> Optional["TranscriptionUsage"]:
        """
        Record transcription usage - upserts into daily aggregate.

        Commits the session on success and rolls back on failure; errors
        are logged and swallowed so usage tracking never breaks the caller.

        Args:
            user_id: User ID who made the request
            connector_type: Type of connector (openai_whisper, openai_transcribe, asr_endpoint)
            audio_duration_seconds: Duration of audio transcribed in seconds
            model_name: Name of the model used
            estimated_cost: Pre-calculated cost (if None, calculated from pricing config)

        Returns:
            The created/updated TranscriptionUsage row, or None on error.
        """
        try:
            today = date.today()

            # Calculate cost if not provided
            if estimated_cost is None:
                cost_per_minute = get_transcription_cost_per_minute(connector_type, model_name)
                estimated_cost = (audio_duration_seconds / 60.0) * cost_per_minute

            # Find or create today's record for this user/connector.
            # NOTE(review): this read-then-write is not atomic; concurrent
            # requests for the same (user, date, connector) could race.
            # Acceptable for usage accounting, but a DB-level upsert would
            # be stricter — confirm whether concurrency matters here.
            usage = TranscriptionUsage.query.filter_by(
                user_id=user_id,
                date=today,
                connector_type=connector_type
            ).first()

            if usage:
                # Accumulate into the existing daily aggregate
                usage.audio_duration_seconds += audio_duration_seconds
                usage.estimated_cost += estimated_cost
                usage.request_count += 1
                if model_name:
                    usage.model_name = model_name  # Update to latest model used
            else:
                # Create new record (estimated_cost is always a float by now)
                usage = TranscriptionUsage(
                    user_id=user_id,
                    date=today,
                    connector_type=connector_type,
                    audio_duration_seconds=audio_duration_seconds,
                    request_count=1,
                    model_name=model_name,
                    estimated_cost=estimated_cost
                )
                db.session.add(usage)

            db.session.commit()
            logger.debug(f"Recorded {audio_duration_seconds}s transcription for user {user_id}, connector {connector_type}")
            return usage

        except Exception as e:
            logger.error(f"Failed to record transcription usage: {e}")
            db.session.rollback()
            return None

    def get_monthly_usage(self, user_id: int, year: Optional[int] = None, month: Optional[int] = None) -> int:
        """Get total seconds transcribed by a user in a given month.

        Defaults to the current calendar month when year/month are omitted.
        """
        if year is None:
            year = date.today().year
        if month is None:
            month = date.today().month

        result = db.session.query(func.sum(TranscriptionUsage.audio_duration_seconds)).filter(
            TranscriptionUsage.user_id == user_id,
            extract('year', TranscriptionUsage.date) == year,
            extract('month', TranscriptionUsage.date) == month
        ).scalar()

        # SUM over zero rows is NULL/None — treat as 0
        return result or 0

    def get_monthly_cost(self, user_id: int, year: Optional[int] = None, month: Optional[int] = None) -> float:
        """Get total estimated cost for a user in a given month.

        Defaults to the current calendar month when year/month are omitted.
        """
        if year is None:
            year = date.today().year
        if month is None:
            month = date.today().month

        result = db.session.query(func.sum(TranscriptionUsage.estimated_cost)).filter(
            TranscriptionUsage.user_id == user_id,
            extract('year', TranscriptionUsage.date) == year,
            extract('month', TranscriptionUsage.date) == month
        ).scalar()

        # SUM over zero rows is NULL/None — treat as 0.0
        return result or 0.0

    def check_budget(self, user_id: int) -> Tuple[bool, float, Optional[str]]:
        """
        Check if user is within transcription budget.

        Returns:
            (can_proceed, usage_percentage, message)
            - can_proceed: False if hard cap (100%) reached
            - usage_percentage: 0-100+
            - message: Warning/error message if applicable

        Fails open: if the user or budget cannot be checked, the request
        is allowed rather than blocked.
        """
        try:
            user = db.session.get(User, user_id)
            if not user or not user.monthly_transcription_budget:
                return (True, 0, None)  # No budget = unlimited

            current_usage = self.get_monthly_usage(user_id)
            budget = user.monthly_transcription_budget  # guaranteed non-zero above
            percentage = (current_usage / budget) * 100

            if percentage >= 100:
                minutes_used = current_usage // 60
                minutes_budget = budget // 60
                return (False, percentage,
                        f"Monthly transcription budget exceeded ({minutes_used}/{minutes_budget} minutes). Contact admin for more time.")
            elif percentage >= 80:
                # Soft-warning threshold at 80% of the monthly budget
                return (True, percentage,
                        f"Warning: {percentage:.1f}% of monthly transcription budget used.")
            else:
                return (True, percentage, None)

        except Exception as e:
            logger.error(f"Failed to check transcription budget for user {user_id}: {e}")
            # Fail open - allow the request if we can't check
            return (True, 0, None)

    def get_daily_stats(self, days: int = 30, user_id: Optional[int] = None) -> List[Dict]:
        """Get daily transcription usage for charts.

        Args:
            days: Size of the trailing window, including today.
            user_id: Restrict to a single user when given; None means all users.

        Returns:
            One dict per day (sorted ascending, missing days zero-filled) with
            'date', 'total_seconds', 'total_minutes', 'cost' and a per-connector
            breakdown under 'by_connector'.
        """
        start_date = date.today() - timedelta(days=days - 1)

        query = db.session.query(
            TranscriptionUsage.date,
            TranscriptionUsage.connector_type,
            func.sum(TranscriptionUsage.audio_duration_seconds).label('seconds'),
            func.sum(TranscriptionUsage.estimated_cost).label('cost')
        ).filter(TranscriptionUsage.date >= start_date)

        # `is not None` so user_id == 0 is still treated as a filter
        if user_id is not None:
            query = query.filter(TranscriptionUsage.user_id == user_id)

        results = query.group_by(TranscriptionUsage.date, TranscriptionUsage.connector_type).all()

        # Organize by date
        stats = {}
        for r in results:
            date_str = r.date.isoformat()
            if date_str not in stats:
                stats[date_str] = {'date': date_str, 'total_seconds': 0, 'total_minutes': 0, 'cost': 0.0, 'by_connector': {}}
            stats[date_str]['total_seconds'] += r.seconds or 0
            stats[date_str]['total_minutes'] = stats[date_str]['total_seconds'] // 60
            stats[date_str]['cost'] += r.cost or 0
            stats[date_str]['by_connector'][r.connector_type] = {
                'seconds': r.seconds or 0,
                'minutes': (r.seconds or 0) // 60
            }

        # Fill in missing dates with zeros
        all_dates = []
        current = start_date
        while current <= date.today():
            date_str = current.isoformat()
            if date_str not in stats:
                stats[date_str] = {'date': date_str, 'total_seconds': 0, 'total_minutes': 0, 'cost': 0.0, 'by_connector': {}}
            all_dates.append(date_str)
            current += timedelta(days=1)

        return [stats[d] for d in sorted(all_dates)]

    def get_monthly_stats(self, months: int = 12) -> List[Dict]:
        """Get monthly transcription usage for charts.

        Aggregates across all users; returns at most the last `months`
        months that have any recorded usage, oldest first.
        """
        results = db.session.query(
            extract('year', TranscriptionUsage.date).label('year'),
            extract('month', TranscriptionUsage.date).label('month'),
            func.sum(TranscriptionUsage.audio_duration_seconds).label('seconds'),
            func.sum(TranscriptionUsage.estimated_cost).label('cost')
        ).group_by('year', 'month').order_by('year', 'month').all()

        # Get last N months
        monthly_data = [
            {
                'year': int(r.year),
                'month': int(r.month),
                'seconds': r.seconds or 0,
                'minutes': (r.seconds or 0) // 60,
                'cost': r.cost or 0
            }
            for r in results
        ]

        return monthly_data[-months:] if len(monthly_data) > months else monthly_data

    def get_user_stats(self) -> List[Dict]:
        """Get per-user transcription usage breakdown for current month.

        Users with no usage this month are included (outer join) with
        zeroed usage/cost. 'percentage' is 0 for users without a budget.
        """
        today = date.today()

        # NOTE(review): selecting username/budget while grouping only by
        # User.id relies on lenient GROUP BY handling (works on SQLite /
        # MySQL; strict backends may reject it) — confirm target DB.
        results = db.session.query(
            User.id,
            User.username,
            User.monthly_transcription_budget,
            func.sum(TranscriptionUsage.audio_duration_seconds).label('usage'),
            func.sum(TranscriptionUsage.estimated_cost).label('cost')
        ).outerjoin(
            TranscriptionUsage,
            (User.id == TranscriptionUsage.user_id) &
            (extract('year', TranscriptionUsage.date) == today.year) &
            (extract('month', TranscriptionUsage.date) == today.month)
        ).group_by(User.id).all()

        return [
            {
                'user_id': r.id,
                'username': r.username,
                'monthly_budget_seconds': r.monthly_transcription_budget,
                'monthly_budget_minutes': (r.monthly_transcription_budget // 60) if r.monthly_transcription_budget else None,
                'current_usage_seconds': r.usage or 0,
                'current_usage_minutes': (r.usage or 0) // 60,
                'cost': r.cost or 0,
                'percentage': ((r.usage or 0) / r.monthly_transcription_budget * 100)
                if r.monthly_transcription_budget else 0
            }
            for r in results
        ]

    def get_today_usage(self, user_id: Optional[int] = None) -> Dict:
        """Get today's transcription usage.

        Args:
            user_id: Restrict to a single user when given; None means all users.

        Returns:
            Dict with 'seconds', 'minutes' and 'cost' totals for today.
        """
        today = date.today()

        query = db.session.query(
            func.sum(TranscriptionUsage.audio_duration_seconds).label('seconds'),
            func.sum(TranscriptionUsage.estimated_cost).label('cost')
        ).filter(TranscriptionUsage.date == today)

        # `is not None` so user_id == 0 is still treated as a filter
        if user_id is not None:
            query = query.filter(TranscriptionUsage.user_id == user_id)

        # Aggregate query always yields exactly one row (values None if no data)
        result = query.first()

        return {
            'seconds': result.seconds or 0,
            'minutes': (result.seconds or 0) // 60,
            'cost': result.cost or 0
        }
|
||||
|
||||
|
||||
# Singleton instance — import this shared tracker rather than instantiating
# TranscriptionTracker directly (the class holds no per-instance state).
transcription_tracker = TranscriptionTracker()
|
||||
Reference in New Issue
Block a user