Initial release: DictIA v0.8.14-alpha (fork of Speakr, AGPL-3.0)

InnovA AI
2026-03-16 21:47:37 +00:00
commit 42772a31ed
365 changed files with 103572 additions and 0 deletions


@@ -0,0 +1,971 @@
#!/usr/bin/env python3
"""
Test suite for Speaker API v1 endpoints.
Covers:
- PUT /recordings/<id>/speakers/assign (17 tests)
- POST /recordings/<id>/speakers/identify (10 tests)
- PUT /settings/auto-summarization (5 tests)
- Regression for GET /speakers and GET /recordings/<id>/speakers (2 tests)
Pattern follows tests/test_api_v1_upload.py — standalone, no pytest fixtures.
"""
import json
import secrets
import sys
import os
from unittest.mock import patch, MagicMock
# Add parent directory so we can import the app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.app import app, db
from src.models import User, APIToken, Recording, Speaker
from src.utils.token_auth import hash_token
# ---------------------------------------------------------------------------
# Test data constants
# ---------------------------------------------------------------------------
SAMPLE_TRANSCRIPTION_JSON = json.dumps([
{"speaker": "SPEAKER_00", "sentence": "Hi, I'm Alice."},
{"speaker": "SPEAKER_01", "sentence": "Hello Alice, I'm Bob."},
{"speaker": "SPEAKER_00", "sentence": "Nice to meet you, Bob."},
])
SAMPLE_TRANSCRIPTION_TEXT = (
"[SPEAKER_00]: Hi, I'm Alice.\n"
"[SPEAKER_01]: Hello Alice, I'm Bob.\n"
"[SPEAKER_00]: Nice to meet you, Bob."
)
SAMPLE_EMBEDDINGS = json.dumps({
"SPEAKER_00": [0.1] * 256,
"SPEAKER_01": [0.2] * 256,
})
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_or_create_test_user(suffix=""):
"""Get or create a test user. Returns (user, created_bool)."""
username = f"speaker_test_user{suffix}"
user = User.query.filter_by(username=username).first()
created = False
if not user:
user = User(
username=username,
email=f"{username}@local.test",
name="Test User" if not suffix else None,
)
db.session.add(user)
db.session.commit()
created = True
return user, created
def _create_api_token(user):
"""Create a fresh API token. Returns (token_record, plaintext)."""
plaintext = f"test-token-{secrets.token_urlsafe(16)}"
token = APIToken(
user_id=user.id,
token_hash=hash_token(plaintext),
name="test-api-token",
)
db.session.add(token)
db.session.commit()
return token, plaintext
def _create_test_recording(user, transcription=None, speaker_embeddings=None, status="COMPLETED"):
"""Create a Recording owned by *user*."""
rec = Recording(
user_id=user.id,
title="Test Recording",
status=status,
transcription=transcription,
speaker_embeddings=speaker_embeddings,
)
db.session.add(rec)
db.session.commit()
return rec
def _create_test_speaker(user, name="Alice"):
"""Create a Speaker owned by *user*."""
speaker = Speaker(name=name, user_id=user.id)
db.session.add(speaker)
db.session.commit()
return speaker
def _cleanup(*objects):
"""Delete DB objects in reverse order, committing once."""
for obj in reversed(objects):
try:
db.session.delete(obj)
except Exception:
db.session.rollback()
try:
merged = db.session.merge(obj)
db.session.delete(merged)
except Exception:
pass
db.session.commit()
# =========================================================================
# Group 1: PUT /recordings/<id>/speakers/assign (17 tests)
# =========================================================================
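# Request shape accepted by the assign endpoint, as exercised below (inferred
# from these tests, not an authoritative API spec):
#
#   PUT /api/v1/recordings/<id>/speakers/assign
#   {
#     "speaker_map": {
#       "SPEAKER_00": "Alice",                        # plain string form
#       "SPEAKER_01": {"name": "Bob", "isMe": false}  # object form; isMe=true
#     },                                              #   resolves to user.name or "Me"
#     "regenerate_summary": true                      # optional, queues a summary job
#   }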
def test_assign_no_auth():
"""No token -> 302 redirect (Flask-Login)."""
with app.app_context():
user, cu = _get_or_create_test_user()
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
json={"speaker_map": {}})
assert resp.status_code in (302, 401), f"Expected 302/401, got {resp.status_code}"
return True
finally:
_cleanup(rec)
if cu:
_cleanup(user)
def test_assign_recording_not_found():
"""Nonexistent recording ID -> 404."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
client = app.test_client()
try:
resp = client.put("/api/v1/recordings/999999/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {}})
assert resp.status_code == 404, f"Expected 404, got {resp.status_code}"
return True
finally:
_cleanup(token_rec)
if cu:
_cleanup(user)
def test_assign_wrong_user_recording():
"""Other user's recording -> 403."""
with app.app_context():
owner, co = _get_or_create_test_user("_owner")
other, cu = _get_or_create_test_user("_other")
token_rec, token = _create_api_token(other)
rec = _create_test_recording(owner, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {"SPEAKER_00": "Alice"}})
assert resp.status_code == 403, f"Expected 403, got {resp.status_code}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(other)
if co:
_cleanup(owner)
def test_assign_missing_speaker_map():
"""Body {} -> 400 'speaker_map is required'."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={})
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
body = resp.get_json()
assert "speaker_map" in body.get("error", "").lower(), f"Unexpected error: {body}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_invalid_speaker_map_type():
"""speaker_map: 'string' -> 400."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": "not a dict"})
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_string_value_json_transcript():
"""Happy path: string names update JSON segments + participants."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {"SPEAKER_00": "Alice", "SPEAKER_01": "Bob"}})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
assert body.get("success") is True
# Verify participants
participants = body["recording"]["participants"]
assert "Alice" in participants and "Bob" in participants
# Verify transcription was updated
db.session.refresh(rec)
segments = json.loads(rec.transcription)
assert segments[0]["speaker"] == "Alice"
assert segments[1]["speaker"] == "Bob"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_object_value_with_name():
"""Happy path: {name, isMe} object format."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {
"SPEAKER_00": {"name": "Alice", "isMe": False},
}})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
db.session.refresh(rec)
segments = json.loads(rec.transcription)
assert segments[0]["speaker"] == "Alice"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_is_me_flag_with_user_name():
"""isMe: true resolves to user.name."""
with app.app_context():
user, cu = _get_or_create_test_user() # user.name == "Test User"
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {
"SPEAKER_00": {"name": "", "isMe": True},
}})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
db.session.refresh(rec)
segments = json.loads(rec.transcription)
assert segments[0]["speaker"] == "Test User", f"Got {segments[0]['speaker']}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_is_me_flag_without_user_name():
"""isMe: true falls back to 'Me' when user.name is None."""
with app.app_context():
user, cu = _get_or_create_test_user("_noname")
# Ensure user.name is None
user.name = None
db.session.commit()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {
"SPEAKER_00": {"name": "", "isMe": True},
}})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
db.session.refresh(rec)
segments = json.loads(rec.transcription)
assert segments[0]["speaker"] == "Me", f"Got {segments[0]['speaker']}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_plain_text_transcript():
"""Replaces [SPEAKER_XX] in plain text format."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_TEXT)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {"SPEAKER_00": "Alice", "SPEAKER_01": "Bob"}})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
db.session.refresh(rec)
assert "[Alice]" in rec.transcription
assert "[Bob]" in rec.transcription
assert "[SPEAKER_00]" not in rec.transcription
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_speaker_xx_filtered_from_participants():
"""Unresolved SPEAKER_XX labels excluded from participants."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
# Only assign one speaker - SPEAKER_01 stays unresolved
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {"SPEAKER_00": "Alice"}})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
participants = body["recording"]["participants"]
assert "SPEAKER_01" not in participants, f"SPEAKER_01 should be filtered: {participants}"
assert "Alice" in participants
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_invalid_value_type():
"""Array value -> 400 'Invalid value type'."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {"SPEAKER_00": [1, 2, 3]}})
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
body = resp.get_json()
assert "invalid value type" in body.get("error", "").lower(), f"Unexpected: {body}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_empty_speaker_map():
"""Empty speaker_map {} -> 200 with no changes."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {}})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
assert body.get("success") is True
# Transcription should be unchanged
db.session.refresh(rec)
segments = json.loads(rec.transcription)
assert segments[0]["speaker"] == "SPEAKER_00"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_regenerate_summary():
"""regenerate_summary: true -> job_queue.enqueue called, summary_queued: true."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
mock_jq = MagicMock()
mock_jq.enqueue = MagicMock(return_value="job-123")
with patch("src.services.job_queue.job_queue", mock_jq):
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={
"speaker_map": {"SPEAKER_00": "Alice"},
"regenerate_summary": True,
})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
assert body.get("summary_queued") is True
mock_jq.enqueue.assert_called_once()
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_embeddings_updated():
"""With speaker_embeddings -> update_speaker_embedding called, counts returned."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(
user,
transcription=SAMPLE_TRANSCRIPTION_JSON,
speaker_embeddings=SAMPLE_EMBEDDINGS,
)
speaker = _create_test_speaker(user, "Alice")
client = app.test_client()
try:
mock_update = MagicMock()
mock_snippets = MagicMock(return_value=2)
with patch("src.services.speaker_embedding_matcher.update_speaker_embedding", mock_update), \
patch("src.services.speaker_snippets.create_speaker_snippets", mock_snippets):
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {"SPEAKER_00": "Alice"}})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
assert body.get("embeddings_updated") >= 1, f"embeddings_updated: {body}"
mock_update.assert_called()
return True
finally:
_cleanup(rec, speaker, token_rec)
if cu:
_cleanup(user)
def test_assign_no_transcription():
"""Recording without transcription -> speakers applied to empty content gracefully."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=None)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {"SPEAKER_00": "Alice"}})
# Should succeed (or at least not 500)
assert resp.status_code in (200, 400), f"Expected 200/400, got {resp.status_code}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_assign_whitespace_name_trimmed():
"""Names with leading/trailing whitespace get trimmed."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
headers={"X-API-Token": token},
json={"speaker_map": {"SPEAKER_00": " Alice "}})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
db.session.refresh(rec)
segments = json.loads(rec.transcription)
assert segments[0]["speaker"] == "Alice", f"Name not trimmed: '{segments[0]['speaker']}'"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
# =========================================================================
# Group 2: POST /recordings/<id>/speakers/identify (10 tests)
# =========================================================================
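# Identify flow exercised below (a sketch inferred from these tests): the
# endpoint requires a JSON-list transcript, sends it to the LLM, and expects a
# {"SPEAKER_XX": "Name"} JSON object back; placeholder answers such as
# "Unknown" / "N/A" are normalised to "" in the returned speaker_map.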
def test_identify_no_auth():
"""No token -> 302."""
with app.app_context():
user, cu = _get_or_create_test_user()
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify")
assert resp.status_code in (302, 401), f"Expected 302/401, got {resp.status_code}"
return True
finally:
_cleanup(rec)
if cu:
_cleanup(user)
def test_identify_recording_not_found():
"""Nonexistent ID -> 404."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
client = app.test_client()
try:
resp = client.post("/api/v1/recordings/999999/speakers/identify",
headers={"X-API-Token": token})
assert resp.status_code == 404, f"Expected 404, got {resp.status_code}"
return True
finally:
_cleanup(token_rec)
if cu:
_cleanup(user)
def test_identify_wrong_user_recording():
"""Other user's recording -> 403."""
with app.app_context():
owner, co = _get_or_create_test_user("_id_owner")
other, cu = _get_or_create_test_user("_id_other")
token_rec, token = _create_api_token(other)
rec = _create_test_recording(owner, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
headers={"X-API-Token": token})
assert resp.status_code == 403, f"Expected 403, got {resp.status_code}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(other)
if co:
_cleanup(owner)
def test_identify_no_transcription():
"""No transcription -> 400."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=None)
client = app.test_client()
try:
resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
headers={"X-API-Token": token})
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_identify_non_json_transcription():
"""Plain text -> 400."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_TEXT)
client = app.test_client()
try:
resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
headers={"X-API-Token": token})
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_identify_json_but_not_list():
"""Dict JSON -> 400."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=json.dumps({"key": "value"}))
client = app.test_client()
try:
resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
headers={"X-API-Token": token})
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_identify_happy_path():
"""Mock LLM returns names -> 200 with speaker_map."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
# Build a mock LLM completion response
mock_completion = MagicMock()
mock_completion.choices = [MagicMock()]
mock_completion.choices[0].message.content = json.dumps({
"SPEAKER_00": "Alice",
"SPEAKER_01": "Bob",
})
with patch("src.services.llm.call_llm_completion", return_value=mock_completion), \
patch("src.models.system.SystemSetting") as mock_ss:
mock_ss.get_setting.return_value = 30000
resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
headers={"X-API-Token": token})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
assert body.get("success") is True
sm = body.get("speaker_map", {})
assert sm.get("SPEAKER_00") == "Alice"
assert sm.get("SPEAKER_01") == "Bob"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_identify_post_processing_unknown_values():
"""'Unknown'/'N/A' cleared to ''."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
mock_completion = MagicMock()
mock_completion.choices = [MagicMock()]
mock_completion.choices[0].message.content = json.dumps({
"SPEAKER_00": "Unknown",
"SPEAKER_01": "N/A",
})
with patch("src.services.llm.call_llm_completion", return_value=mock_completion), \
patch("src.models.system.SystemSetting") as mock_ss:
mock_ss.get_setting.return_value = 30000
resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
headers={"X-API-Token": token})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
sm = body.get("speaker_map", {})
assert sm.get("SPEAKER_00") == "", f"Expected empty, got {sm.get('SPEAKER_00')}"
assert sm.get("SPEAKER_01") == "", f"Expected empty, got {sm.get('SPEAKER_01')}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_identify_no_speakers_in_transcript():
"""Segments without speaker field -> 400."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
no_speakers = json.dumps([{"sentence": "Hello"}, {"sentence": "World"}])
rec = _create_test_recording(user, transcription=no_speakers)
client = app.test_client()
try:
resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
headers={"X-API-Token": token})
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
def test_identify_llm_error():
"""LLM raises exception -> 500."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
with patch("src.services.llm.call_llm_completion",
side_effect=RuntimeError("LLM down")), \
patch("src.models.system.SystemSetting") as mock_ss:
mock_ss.get_setting.return_value = 30000
resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
headers={"X-API-Token": token})
assert resp.status_code == 500, f"Expected 500, got {resp.status_code}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
# =========================================================================
# Group 3: PUT /settings/auto-summarization (5 tests)
# =========================================================================
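# Settings toggle exercised below: PUT /api/v1/settings/auto-summarization with
# {"enabled": true|false}; the response echoes {"auto_summarization": <bool>}
# and the flag is persisted on the user record.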
def test_auto_summarization_no_auth():
"""No token -> 302."""
with app.app_context():
client = app.test_client()
resp = client.put("/api/v1/settings/auto-summarization",
json={"enabled": True})
assert resp.status_code in (302, 401), f"Expected 302/401, got {resp.status_code}"
return True
def test_auto_summarization_missing_enabled():
"""Body {} -> 400."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
client = app.test_client()
try:
resp = client.put("/api/v1/settings/auto-summarization",
headers={"X-API-Token": token},
json={})
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
body = resp.get_json()
assert "enabled" in body.get("error", "").lower(), f"Unexpected: {body}"
return True
finally:
_cleanup(token_rec)
if cu:
_cleanup(user)
def test_auto_summarization_invalid_json():
"""Non-JSON body -> 400."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
client = app.test_client()
try:
resp = client.put("/api/v1/settings/auto-summarization",
headers={"X-API-Token": token,
"Content-Type": "application/json"},
data="not valid json")
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
return True
finally:
_cleanup(token_rec)
if cu:
_cleanup(user)
def test_auto_summarization_enable():
"""enabled: true -> updates user, returns true."""
with app.app_context():
user, cu = _get_or_create_test_user()
user.auto_summarization = False
db.session.commit()
token_rec, token = _create_api_token(user)
client = app.test_client()
try:
resp = client.put("/api/v1/settings/auto-summarization",
headers={"X-API-Token": token},
json={"enabled": True})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
assert body.get("auto_summarization") is True
db.session.refresh(user)
assert user.auto_summarization is True
return True
finally:
_cleanup(token_rec)
if cu:
_cleanup(user)
def test_auto_summarization_disable():
"""enabled: false -> updates user, returns false."""
with app.app_context():
user, cu = _get_or_create_test_user()
user.auto_summarization = True
db.session.commit()
token_rec, token = _create_api_token(user)
client = app.test_client()
try:
resp = client.put("/api/v1/settings/auto-summarization",
headers={"X-API-Token": token},
json={"enabled": False})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
assert body.get("auto_summarization") is False
db.session.refresh(user)
assert user.auto_summarization is False
return True
finally:
_cleanup(token_rec)
if cu:
_cleanup(user)
# =========================================================================
# Group 4: Regression tests (2 tests)
# =========================================================================
def test_regression_get_speakers_list():
"""GET /speakers still returns user's speakers."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
speaker = _create_test_speaker(user, "Regression Speaker")
client = app.test_client()
try:
resp = client.get("/api/v1/speakers",
headers={"X-API-Token": token})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
names = [s["name"] for s in body.get("speakers", [])]
assert "Regression Speaker" in names, f"Speaker not found: {names}"
return True
finally:
_cleanup(speaker, token_rec)
if cu:
_cleanup(user)
def test_regression_get_recording_speakers():
"""GET /recordings/<id>/speakers still returns transcript speakers."""
with app.app_context():
user, cu = _get_or_create_test_user()
token_rec, token = _create_api_token(user)
rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
client = app.test_client()
try:
with patch("src.services.speaker_embedding_matcher.find_matching_speakers", return_value={}):
resp = client.get(f"/api/v1/recordings/{rec.id}/speakers",
headers={"X-API-Token": token})
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
body = resp.get_json()
labels = [s["label"] for s in body.get("speakers", [])]
assert "SPEAKER_00" in labels and "SPEAKER_01" in labels, f"Labels: {labels}"
return True
finally:
_cleanup(rec, token_rec)
if cu:
_cleanup(user)
# =========================================================================
# Runner
# =========================================================================
ALL_TESTS = [
# Group 1: assign
test_assign_no_auth,
test_assign_recording_not_found,
test_assign_wrong_user_recording,
test_assign_missing_speaker_map,
test_assign_invalid_speaker_map_type,
test_assign_string_value_json_transcript,
test_assign_object_value_with_name,
test_assign_is_me_flag_with_user_name,
test_assign_is_me_flag_without_user_name,
test_assign_plain_text_transcript,
test_assign_speaker_xx_filtered_from_participants,
test_assign_invalid_value_type,
test_assign_empty_speaker_map,
test_assign_regenerate_summary,
test_assign_embeddings_updated,
test_assign_no_transcription,
test_assign_whitespace_name_trimmed,
# Group 2: identify
test_identify_no_auth,
test_identify_recording_not_found,
test_identify_wrong_user_recording,
test_identify_no_transcription,
test_identify_non_json_transcription,
test_identify_json_but_not_list,
test_identify_happy_path,
test_identify_post_processing_unknown_values,
test_identify_no_speakers_in_transcript,
test_identify_llm_error,
# Group 3: auto-summarization
test_auto_summarization_no_auth,
test_auto_summarization_missing_enabled,
test_auto_summarization_invalid_json,
test_auto_summarization_enable,
test_auto_summarization_disable,
# Group 4: regression
test_regression_get_speakers_list,
test_regression_get_recording_speakers,
]
def main():
print(f"Running {len(ALL_TESTS)} Speaker API tests...\n")
passed = 0
failed = 0
errors = []
for test_fn in ALL_TESTS:
name = test_fn.__name__
try:
result = test_fn()
if result:
print(f" PASS {name}")
passed += 1
else:
print(f" FAIL {name} (returned False)")
failed += 1
errors.append(name)
except Exception as e:
print(f" ERROR {name}: {e}")
failed += 1
errors.append(name)
print(f"\n{'=' * 60}")
print(f"Results: {passed} passed, {failed} failed out of {len(ALL_TESTS)}")
if errors:
print("Failed tests:")
for e in errors:
print(f" - {e}")
print('=' * 60)
sys.exit(0 if failed == 0 else 1)
if __name__ == "__main__":
main()


@@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""
Integration test for API v1 recording upload endpoint.
Validates API token authentication and expected 400 response when no file is provided.
"""
import secrets
import sys
import os
# Add the parent directory to the path to import app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.app import app, db
from src.models import User, APIToken
from src.utils.token_auth import hash_token
def _get_or_create_test_user():
user = User.query.filter_by(username="api_test_user").first()
created = False
if not user:
user = User(username="api_test_user", email="api_test_user@local.test")
db.session.add(user)
db.session.commit()
created = True
return user, created
def _create_api_token(user):
plaintext = f"test-token-{secrets.token_urlsafe(16)}"
token = APIToken(
user_id=user.id,
token_hash=hash_token(plaintext),
name="test-api-token"
)
db.session.add(token)
db.session.commit()
return token, plaintext
def test_upload_requires_file():
with app.app_context():
user, created_user = _get_or_create_test_user()
token_record, token = _create_api_token(user)
client = app.test_client()
try:
response = client.post(
"/api/v1/recordings/upload",
headers={"X-API-Token": token}
)
if response.status_code != 400:
print(f"❌ Expected 400, got {response.status_code}")
return False
payload = response.get_json(silent=True) or {}
if payload.get("error") != "No file provided":
print(f"❌ Unexpected error payload: {payload}")
return False
print("✅ Token auth works and missing file returns 400 as expected")
return True
finally:
db.session.delete(token_record)
db.session.commit()
if created_user:
db.session.delete(user)
db.session.commit()
def main():
print("🚀 Running API v1 upload test...\n")
ok = test_upload_requires_file()
print("\n" + ("✅ PASS" if ok else "❌ FAIL"))
sys.exit(0 if ok else 1)
if __name__ == "__main__":
main()

tests/test_audit.py

@@ -0,0 +1,331 @@
#!/usr/bin/env python3
"""
Tests for the Loi 25 audit system.
Covers:
- audit_access(): adds to session, does NOT commit
- audit_login(): commits independently
- audit_failed_login(): commits, uses email_hash (not plain email)
- _is_recent_duplicate(): deduplication window
- get_access_logs() / get_auth_logs(): pagination and filters
- Admin API endpoints: /api/admin/audit/status, /access, /auth
"""
import os
import sys
import hashlib
# Add the parent directory to the path to import app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.environ.setdefault('SQLALCHEMY_DATABASE_URI', 'sqlite:///:memory:')
os.environ['ENABLE_AUDIT_LOG'] = 'true'  # force-enable so the audit code paths run
from src.app import app, db
from src.models import User
from src.models.access_log import AccessLog
from src.models.auth_log import AuthLog
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_user(username, is_admin=False):
user = User(username=username, email=f"{username}@test.local", is_admin=is_admin)
user.set_password("TestPass1!")
db.session.add(user)
db.session.commit()
return user
def _login_client(client, user):
"""Log *user* into the test client by injecting Flask-Login session keys."""
with client.session_transaction() as sess:
sess['_user_id'] = str(user.id)
sess['_fresh'] = True
# ---------------------------------------------------------------------------
# Service-level tests
# ---------------------------------------------------------------------------
def test_audit_access_no_commit():
"""audit_access() adds to session but does NOT commit."""
with app.app_context():
db.create_all()
user = _make_user("audit_no_commit")
try:
initial_count = AccessLog.query.count()
from src.services.audit import audit_access
with app.test_request_context():
from flask_login import login_user
login_user(user)
log = audit_access('edit', 'recording', 1, user_id=user.id)
# Log object returned but not yet in DB (no commit happened)
assert log is not None, "audit_access should return a log object"
# The session hasn't been committed, so count is still the same
# (SQLite in :memory: — flush to verify it's in session)
db.session.flush()
assert AccessLog.query.count() == initial_count + 1, "Log should be in session after flush"
db.session.rollback() # undo — simulates caller rollback
assert AccessLog.query.count() == initial_count, "Log should be gone after rollback"
print("✅ audit_access() does not commit")
return True
finally:
db.session.delete(user)
db.session.commit()
def test_audit_login_commits():
"""audit_login() commits its own transaction."""
with app.app_context():
db.create_all()
user = _make_user("audit_login_user")
try:
initial_count = AuthLog.query.count()
from src.services.audit import audit_login
with app.test_request_context():
audit_login(user.id)
# Should be committed — visible in a fresh query
assert AuthLog.query.count() == initial_count + 1, "auth log should be committed"
log = AuthLog.query.order_by(AuthLog.id.desc()).first()
assert log.action == 'login'
assert log.user_id == user.id
print("✅ audit_login() commits independently")
return True
finally:
AuthLog.query.filter_by(user_id=user.id).delete()
db.session.delete(user)
db.session.commit()
def test_audit_failed_login_uses_email_hash():
"""audit_failed_login() stores email_hash, not plain email."""
with app.app_context():
db.create_all()
initial_count = AuthLog.query.count()
email = "attacker-target@example.com"
email_hash = hashlib.sha256(email.lower().encode()).hexdigest()[:16]
from src.services.audit import audit_failed_login
with app.test_request_context():
audit_failed_login(details={'email_hash': email_hash, 'reason': 'wrong_password'})
assert AuthLog.query.count() == initial_count + 1
log = AuthLog.query.order_by(AuthLog.id.desc()).first()
assert log.action == 'failed_login'
assert log.details is not None
assert 'email_hash' in log.details, "Should store email_hash, not plain email"
assert 'email' not in log.details, "Should NOT store plain email"
assert log.details['email_hash'] == email_hash
# Cleanup
db.session.delete(log)
db.session.commit()
print("✅ audit_failed_login() stores email_hash, not plain email")
return True
def test_audit_view_deduplication():
"""audit_view() on the same resource within 5 min creates only one log entry."""
with app.app_context():
db.create_all()
user = _make_user("audit_dedup_user")
try:
initial_count = AccessLog.query.count()
from src.services.audit import audit_view
with app.test_request_context():
from flask_login import login_user
login_user(user)
# First view — should log
log1 = audit_view('recording', 42, user_id=user.id)
db.session.commit()
# Second view within 5 min — should be deduped
log2 = audit_view('recording', 42, user_id=user.id)
if log2 is not None:
db.session.commit()
count_after = AccessLog.query.filter_by(
user_id=user.id, action='view', resource_type='recording', resource_id=42
).count()
assert count_after == 1, f"Expected 1 log entry, got {count_after} (dedup failed)"
print("✅ audit_view() deduplication works (5-min window)")
return True
finally:
AccessLog.query.filter_by(user_id=user.id).delete()
db.session.delete(user)
db.session.commit()
def test_get_access_logs_pagination():
"""get_access_logs() returns paginated results."""
with app.app_context():
db.create_all()
user = _make_user("audit_pag_user")
try:
# Create 5 access log entries
for i in range(5):
log = AccessLog.log_access(
action='view', resource_type='recording', resource_id=i,
user_id=user.id, status='success',
)
db.session.add(log)
db.session.commit()
from src.services.audit import get_access_logs
page1 = get_access_logs(page=1, per_page=3, user_id=user.id)
assert page1.total >= 5
assert len(page1.items) == 3
page2 = get_access_logs(page=2, per_page=3, user_id=user.id)
assert len(page2.items) >= 2
print("✅ get_access_logs() pagination works")
return True
finally:
AccessLog.query.filter_by(user_id=user.id).delete()
db.session.delete(user)
db.session.commit()
# ---------------------------------------------------------------------------
# Admin API endpoint tests
# ---------------------------------------------------------------------------
def test_audit_status_requires_admin():
"""GET /api/admin/audit/status: 401 anon, 403 non-admin, 200 admin."""
with app.app_context():
db.create_all()
regular = _make_user("audit_status_regular")
admin = _make_user("audit_status_admin", is_admin=True)
client = app.test_client()
try:
# Anonymous — should redirect to login (302) or 401
resp = client.get('/api/admin/audit/status')
assert resp.status_code in (401, 302), f"Expected 401/302 for anon, got {resp.status_code}"
# Regular user — should get 403
_login_client(client, regular)
resp = client.get('/api/admin/audit/status')
assert resp.status_code == 403, f"Expected 403 for non-admin, got {resp.status_code}"
# Admin — should get 200
_login_client(client, admin)
resp = client.get('/api/admin/audit/status')
assert resp.status_code == 200, f"Expected 200 for admin, got {resp.status_code}"
data = resp.get_json()
assert 'enabled' in data
print("✅ /api/admin/audit/status access control works")
return True
finally:
db.session.delete(regular)
db.session.delete(admin)
db.session.commit()
def test_audit_access_logs_endpoint():
"""GET /api/admin/audit/access returns paginated logs for admin."""
with app.app_context():
db.create_all()
admin = _make_user("audit_access_ep_admin", is_admin=True)
client = app.test_client()
try:
_login_client(client, admin)
resp = client.get('/api/admin/audit/access?per_page=10')
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.get_json()
assert 'logs' in data
assert 'total' in data
assert 'page' in data
print("✅ /api/admin/audit/access returns correct structure")
return True
finally:
db.session.delete(admin)
db.session.commit()
def test_audit_auth_logs_endpoint():
"""GET /api/admin/audit/auth returns paginated auth logs for admin."""
with app.app_context():
db.create_all()
admin = _make_user("audit_auth_ep_admin", is_admin=True)
client = app.test_client()
try:
_login_client(client, admin)
resp = client.get('/api/admin/audit/auth?per_page=10')
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.get_json()
assert 'logs' in data
assert 'total' in data
print("✅ /api/admin/audit/auth returns correct structure")
return True
finally:
db.session.delete(admin)
db.session.commit()
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
def main():
print("🚀 Running audit system tests...\n")
tests = [
test_audit_access_no_commit,
test_audit_login_commits,
test_audit_failed_login_uses_email_hash,
test_audit_view_deduplication,
test_get_access_logs_pagination,
test_audit_status_requires_admin,
test_audit_access_logs_endpoint,
test_audit_auth_logs_endpoint,
]
passed = 0
failed = 0
for test in tests:
try:
result = test()
if result:
passed += 1
else:
print(f"{test.__name__} returned False")
failed += 1
except Exception as e:
print(f"{test.__name__} raised: {e}")
import traceback
traceback.print_exc()
failed += 1
print(f"\n{'='*40}")
print(f"Results: {passed} passed, {failed} failed")
print("✅ ALL PASS" if failed == 0 else "❌ SOME FAILED")
sys.exit(0 if failed == 0 else 1)
if __name__ == "__main__":
main()

tests/test_bugfixes.py

@@ -0,0 +1,239 @@
#!/usr/bin/env python3
"""
Tests for specific bug fixes.
- Issue #230: Bulk delete crash when recordings have speaker_snippets
- Issue #223: File monitor stability time env var
"""
import json
import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.app import app, db
from src.models import User, Recording
from src.models.speaker_snippet import SpeakerSnippet
# Disable CSRF for testing
app.config['WTF_CSRF_ENABLED'] = False
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_or_create_user():
user = User.query.filter_by(username="bugfix_test_user").first()
if not user:
user = User(username="bugfix_test_user", email="bugfix@local.test")
db.session.add(user)
db.session.commit()
return user
def _create_recording_with_snippets(user):
"""Create a recording that has speaker_snippet records attached."""
rec = Recording(
user_id=user.id,
title="Recording with snippets",
status="COMPLETED",
transcription=json.dumps([
{"speaker": "SPEAKER_00", "sentence": "Hello there."},
]),
)
db.session.add(rec)
db.session.commit()
# We need a speaker to attach snippets to
from src.models import Speaker
speaker = Speaker.query.filter_by(user_id=user.id, name="BugfixTestSpeaker").first()
if not speaker:
speaker = Speaker(name="BugfixTestSpeaker", user_id=user.id)
db.session.add(speaker)
db.session.commit()
snippet = SpeakerSnippet(
speaker_id=speaker.id,
recording_id=rec.id,
segment_index=0,
text_snippet="Hello there.",
)
db.session.add(snippet)
db.session.commit()
return rec, speaker, snippet
# ---------------------------------------------------------------------------
# Issue #230: Deleting recordings with speaker_snippets
# ---------------------------------------------------------------------------
class TestIssue230BulkDeleteCascade:
"""Verify that deleting a recording with speaker_snippets doesn't crash."""
def test_single_delete_with_snippets(self):
"""Single DELETE /recording/<id> should succeed when snippets exist."""
with app.app_context():
user = _get_or_create_user()
rec, speaker, snippet = _create_recording_with_snippets(user)
rec_id = rec.id
snippet_id = snippet.id
with app.test_client() as client:
# Login
with client.session_transaction() as sess:
sess['_user_id'] = str(user.id)
resp = client.delete(f'/recording/{rec_id}')
assert resp.status_code == 200, f"Delete failed: {resp.get_json()}"
data = resp.get_json()
assert data.get('success') is True
# Verify snippet was also deleted
orphan = db.session.get(SpeakerSnippet, snippet_id)
assert orphan is None, "Speaker snippet should have been deleted with recording"
# Cleanup speaker
db.session.delete(speaker)
db.session.commit()
def test_bulk_delete_with_snippets(self):
"""DELETE /api/recordings/bulk should succeed when snippets exist."""
with app.app_context():
user = _get_or_create_user()
rec, speaker, snippet = _create_recording_with_snippets(user)
rec_id = rec.id
snippet_id = snippet.id
with app.test_client() as client:
with client.session_transaction() as sess:
sess['_user_id'] = str(user.id)
resp = client.delete(
'/api/recordings/bulk',
json={'recording_ids': [rec_id]},
content_type='application/json',
)
assert resp.status_code == 200, f"Bulk delete failed: {resp.get_json()}"
data = resp.get_json()
assert data.get('success') is True
assert rec_id in data.get('deleted_ids', [])
# Verify snippet was also deleted
orphan = db.session.get(SpeakerSnippet, snippet_id)
assert orphan is None, "Speaker snippet should have been deleted with recording"
# Cleanup speaker
db.session.delete(speaker)
db.session.commit()
def test_bulk_delete_multiple_with_snippets(self):
"""Bulk deleting multiple recordings (some with snippets) should succeed."""
with app.app_context():
user = _get_or_create_user()
rec1, speaker, snippet = _create_recording_with_snippets(user)
rec2 = Recording(user_id=user.id, title="No snippets", status="COMPLETED")
db.session.add(rec2)
db.session.commit()
rec1_id, rec2_id = rec1.id, rec2.id
with app.test_client() as client:
with client.session_transaction() as sess:
sess['_user_id'] = str(user.id)
resp = client.delete(
'/api/recordings/bulk',
json={'recording_ids': [rec1_id, rec2_id]},
content_type='application/json',
)
assert resp.status_code == 200, f"Bulk delete failed: {resp.get_json()}"
data = resp.get_json()
assert data.get('deleted_count') == 2
# Cleanup speaker
db.session.delete(speaker)
db.session.commit()
# ---------------------------------------------------------------------------
# Issue #223: File monitor stability time
# ---------------------------------------------------------------------------
class TestIssue223StabilityTime:
"""Verify AUTO_PROCESS_STABILITY_TIME env var is respected."""
def test_default_stability_time(self):
"""Without env var, stability_time defaults to 5."""
from src.file_monitor import FileMonitor
monitor = FileMonitor.__new__(FileMonitor)
monitor.logger = MagicMock()
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
f.write(b'fake audio data')
tmp_path = Path(f.name)
try:
with patch.dict(os.environ, {}, clear=False):
# Remove the env var if it exists
os.environ.pop('AUTO_PROCESS_STABILITY_TIME', None)
with patch('time.sleep') as mock_sleep:
monitor._is_file_stable(tmp_path)
mock_sleep.assert_called_once_with(5)
finally:
tmp_path.unlink(missing_ok=True)
def test_custom_stability_time(self):
"""AUTO_PROCESS_STABILITY_TIME=15 should sleep for 15 seconds."""
from src.file_monitor import FileMonitor
monitor = FileMonitor.__new__(FileMonitor)
monitor.logger = MagicMock()
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
f.write(b'fake audio data')
tmp_path = Path(f.name)
try:
with patch.dict(os.environ, {'AUTO_PROCESS_STABILITY_TIME': '15'}):
with patch('time.sleep') as mock_sleep:
# _is_file_stable takes stability_time as a parameter; the file-monitor
# caller reads it from the env, so mirror that read here.
stability_time = int(os.environ.get('AUTO_PROCESS_STABILITY_TIME', '5'))
monitor._is_file_stable(tmp_path, stability_time)
mock_sleep.assert_called_once_with(15)
finally:
tmp_path.unlink(missing_ok=True)
def test_no_hardcoded_cap(self):
"""Stability time should NOT be capped at 2 seconds anymore."""
from src.file_monitor import FileMonitor
monitor = FileMonitor.__new__(FileMonitor)
monitor.logger = MagicMock()
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
f.write(b'fake audio data')
tmp_path = Path(f.name)
try:
with patch('time.sleep') as mock_sleep:
monitor._is_file_stable(tmp_path, stability_time=30)
# Should sleep for 30, NOT min(30, 2) = 2
mock_sleep.assert_called_once_with(30)
finally:
tmp_path.unlink(missing_ok=True)
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
if __name__ == "__main__":
import pytest
sys.exit(pytest.main([__file__, "-v"]))


@@ -0,0 +1,564 @@
#!/usr/bin/env python3
"""
Test script for the transcription connector architecture.
This script tests:
1. Connector auto-detection from environment variables
2. Backwards compatibility with legacy config
3. Connector specifications and capabilities
4. Chunking logic (connector-aware)
5. Codec handling per connector
6. Request/Response data types
Run with: docker exec speakr-dev python /app/tests/test_connector_architecture.py
"""
import os
import sys
import io
import json
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
# Test results tracking
PASSED = 0
FAILED = 0
ERRORS = []
def run_test(name, func):
"""Run a test function and track results."""
global PASSED, FAILED, ERRORS
try:
func()
print(f"PASS {name}")
PASSED += 1
except AssertionError as e:
print(f"FAIL {name}: {e}")
FAILED += 1
ERRORS.append((name, str(e)))
except Exception as e:
print(f"ERROR {name}: EXCEPTION - {e}")
FAILED += 1
ERRORS.append((name, f"Exception: {e}"))
def clear_env():
"""Clear all transcription-related environment variables."""
keys_to_clear = [
'TRANSCRIPTION_CONNECTOR', 'TRANSCRIPTION_API_KEY', 'TRANSCRIPTION_BASE_URL',
'TRANSCRIPTION_MODEL', 'WHISPER_MODEL', 'USE_ASR_ENDPOINT', 'ASR_BASE_URL',
'ASR_DIARIZE', 'ASR_RETURN_SPEAKER_EMBEDDINGS', 'ASR_TIMEOUT',
'ASR_MIN_SPEAKERS', 'ASR_MAX_SPEAKERS', 'ENABLE_CHUNKING', 'CHUNK_LIMIT',
'CHUNK_OVERLAP_SECONDS', 'AUDIO_UNSUPPORTED_CODECS',
]
for key in keys_to_clear:
os.environ.pop(key, None)
def reset_registry():
"""Reset the connector registry singleton."""
from src.services.transcription import registry
registry._registry = None
registry.ConnectorRegistry._instance = None
registry.ConnectorRegistry._initialized = False
registry.ConnectorRegistry._active_connector = None
registry.ConnectorRegistry._connector_name = ""
# =============================================================================
# TEST SECTION 1: Base Classes and Data Types
# =============================================================================
def test_base_classes():
"""Test base classes and data types."""
print("\n=== Testing Base Classes ===")
from src.services.transcription.base import (
TranscriptionCapability, ConnectorSpecifications, TranscriptionRequest,
TranscriptionResponse, TranscriptionSegment,
)
def t1():
assert TranscriptionCapability.DIARIZATION is not None
assert TranscriptionCapability.TIMESTAMPS is not None
assert TranscriptionCapability.SPEAKER_COUNT_CONTROL is not None
run_test("TranscriptionCapability enum has expected values", t1)
def t2():
specs = ConnectorSpecifications()
assert specs.max_file_size_bytes is None
assert specs.handles_chunking_internally is False
assert specs.recommended_chunk_seconds == 600
run_test("ConnectorSpecifications has correct defaults", t2)
def t3():
specs = ConnectorSpecifications(
max_file_size_bytes=25 * 1024 * 1024,
handles_chunking_internally=True,
unsupported_codecs=frozenset({'opus'})
)
assert specs.max_file_size_bytes == 25 * 1024 * 1024
assert 'opus' in specs.unsupported_codecs
run_test("ConnectorSpecifications with custom values", t3)
def t4():
audio = io.BytesIO(b"fake audio data")
request = TranscriptionRequest(audio_file=audio, filename="test.wav", diarize=True)
assert request.filename == "test.wav"
assert request.diarize is True
run_test("TranscriptionRequest creation", t4)
def t5():
segments = [
TranscriptionSegment(text="Hello", speaker="SPEAKER_00", start_time=0.0, end_time=1.0),
TranscriptionSegment(text="World", speaker="SPEAKER_01", start_time=1.0, end_time=2.0),
]
response = TranscriptionResponse(text="Hello World", segments=segments, provider="test")
storage = response.to_storage_format()
data = json.loads(storage)
assert len(data) == 2
assert data[0]['speaker'] == "SPEAKER_00"
run_test("TranscriptionResponse to_storage_format", t5)
def t6():
segments = [TranscriptionSegment(text="Hello", speaker="SPEAKER_00")]
response = TranscriptionResponse(text="Hello", segments=segments)
assert response.has_diarization() is True
response2 = TranscriptionResponse(text="Hello", segments=None)
assert response2.has_diarization() is False
run_test("TranscriptionResponse has_diarization", t6)
# =============================================================================
# TEST SECTION 2: Connector Auto-Detection
# =============================================================================
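# Detection precedence exercised by the tests below (inferred from them, not a
# formal spec):
#   1. TRANSCRIPTION_CONNECTOR set             -> that connector, verbatim
#   2. ASR_BASE_URL or USE_ASR_ENDPOINT=true   -> asr_endpoint
#   3. TRANSCRIPTION_MODEL is a gpt-4o-* name  -> openai_transcribe
#   4. otherwise (incl. legacy WHISPER_MODEL)  -> openai_whisper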
def test_auto_detection():
"""Test connector auto-detection from environment variables."""
print("\n=== Testing Connector Auto-Detection ===")
from src.services.transcription.registry import get_registry
def t1():
clear_env()
reset_registry()
os.environ['TRANSCRIPTION_CONNECTOR'] = 'openai_whisper'
os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
os.environ['ASR_BASE_URL'] = 'http://should-be-ignored:9000'
registry = get_registry()
registry.initialize_from_env()
assert registry.get_active_connector_name() == 'openai_whisper'
run_test("Explicit TRANSCRIPTION_CONNECTOR takes priority", t1)
def t2():
clear_env()
reset_registry()
os.environ['ASR_BASE_URL'] = 'http://whisperx:9000'
registry = get_registry()
registry.initialize_from_env()
assert registry.get_active_connector_name() == 'asr_endpoint'
run_test("ASR_BASE_URL auto-detects asr_endpoint", t2)
def t3():
clear_env()
reset_registry()
os.environ['USE_ASR_ENDPOINT'] = 'true'
os.environ['ASR_BASE_URL'] = 'http://whisperx:9000'
registry = get_registry()
registry.initialize_from_env()
assert registry.get_active_connector_name() == 'asr_endpoint'
run_test("Legacy USE_ASR_ENDPOINT=true still works", t3)
def t4():
clear_env()
reset_registry()
os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
os.environ['TRANSCRIPTION_MODEL'] = 'gpt-4o-transcribe-diarize'
registry = get_registry()
registry.initialize_from_env()
assert registry.get_active_connector_name() == 'openai_transcribe'
run_test("gpt-4o model auto-detects openai_transcribe", t4)
def t5():
clear_env()
reset_registry()
os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
os.environ['TRANSCRIPTION_MODEL'] = 'whisper-1'
registry = get_registry()
registry.initialize_from_env()
assert registry.get_active_connector_name() == 'openai_whisper'
run_test("whisper-1 model uses openai_whisper", t5)
def t6():
clear_env()
reset_registry()
os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
os.environ['WHISPER_MODEL'] = 'whisper-1'
registry = get_registry()
registry.initialize_from_env()
assert registry.get_active_connector_name() == 'openai_whisper'
run_test("Legacy WHISPER_MODEL still works", t6)
def t7():
clear_env()
reset_registry()
os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
registry = get_registry()
registry.initialize_from_env()
assert registry.get_active_connector_name() == 'openai_whisper'
run_test("Default falls back to openai_whisper", t7)
# =============================================================================
# TEST SECTION 3: Connector Specifications
# =============================================================================
def test_connector_specifications():
"""Test connector specifications are correctly defined."""
print("\n=== Testing Connector Specifications ===")
from src.services.transcription.connectors.openai_whisper import OpenAIWhisperConnector
from src.services.transcription.connectors.openai_transcribe import OpenAITranscribeConnector
from src.services.transcription.connectors.asr_endpoint import ASREndpointConnector
from src.services.transcription.base import TranscriptionCapability
def t1():
specs = OpenAIWhisperConnector.SPECIFICATIONS
assert specs.max_file_size_bytes == 25 * 1024 * 1024
assert specs.handles_chunking_internally is False
run_test("OpenAI Whisper has 25MB limit", t1)
def t2():
specs = OpenAIWhisperConnector.SPECIFICATIONS
assert specs.unsupported_codecs is not None
assert 'opus' in specs.unsupported_codecs
run_test("OpenAI Whisper declares opus as unsupported", t2)
def t3():
specs = OpenAITranscribeConnector.SPECIFICATIONS
assert specs.handles_chunking_internally is True
assert specs.requires_chunking_param is True
run_test("OpenAI Transcribe has internal chunking", t3)
def t4():
specs = ASREndpointConnector.SPECIFICATIONS
assert specs.max_file_size_bytes is None
assert specs.handles_chunking_internally is True
run_test("ASR Endpoint has no limits (handles internally)", t4)
def t5():
assert TranscriptionCapability.DIARIZATION not in OpenAIWhisperConnector.CAPABILITIES
run_test("OpenAI Whisper does NOT support diarization", t5)
def t6():
# Diarization is added dynamically based on model at instance level
connector = OpenAITranscribeConnector({'api_key': 'test', 'model': 'gpt-4o-transcribe-diarize'})
assert TranscriptionCapability.DIARIZATION in connector.CAPABILITIES
assert connector.supports_diarization is True
run_test("OpenAI Transcribe with diarize model supports diarization", t6)
def t7():
assert TranscriptionCapability.DIARIZATION in ASREndpointConnector.CAPABILITIES
assert TranscriptionCapability.SPEAKER_COUNT_CONTROL in ASREndpointConnector.CAPABILITIES
run_test("ASR Endpoint supports diarization and speaker count control", t7)
def t8():
assert TranscriptionCapability.SPEAKER_COUNT_CONTROL not in OpenAIWhisperConnector.CAPABILITIES
assert TranscriptionCapability.SPEAKER_COUNT_CONTROL not in OpenAITranscribeConnector.CAPABILITIES
run_test("OpenAI connectors do NOT support speaker count control", t8)
# =============================================================================
# TEST SECTION 4: Chunking Logic
# =============================================================================
def test_chunking_logic():
"""Test connector-aware chunking logic."""
print("\n=== Testing Chunking Logic ===")
from src.audio_chunking import get_effective_chunking_config
from src.services.transcription.base import ConnectorSpecifications
def t1():
clear_env()
os.environ['ENABLE_CHUNKING'] = 'true'
os.environ['CHUNK_LIMIT'] = '20MB'
specs = ConnectorSpecifications(handles_chunking_internally=True)
config = get_effective_chunking_config(specs)
assert config.enabled is False
assert config.source == 'connector_internal'
run_test("Connector with internal chunking disables app chunking", t1)
def t2():
clear_env()
os.environ['ENABLE_CHUNKING'] = 'true'
os.environ['CHUNK_LIMIT'] = '15MB'
os.environ['CHUNK_OVERLAP_SECONDS'] = '5'
specs = ConnectorSpecifications(handles_chunking_internally=False)
config = get_effective_chunking_config(specs)
assert config.enabled is True
assert config.source == 'env'
assert config.mode == 'size'
assert config.limit_value == 15.0
run_test("Connector without internal chunking uses ENV settings", t2)
def t3():
clear_env()
os.environ['ENABLE_CHUNKING'] = 'false'
specs = ConnectorSpecifications(handles_chunking_internally=False)
config = get_effective_chunking_config(specs)
assert config.enabled is False
assert config.source == 'disabled'
run_test("ENABLE_CHUNKING=false disables chunking", t3)
def t4():
clear_env()
os.environ['ENABLE_CHUNKING'] = 'true'
os.environ['CHUNK_LIMIT'] = '10m'
specs = ConnectorSpecifications(handles_chunking_internally=False)
config = get_effective_chunking_config(specs)
assert config.enabled is True
assert config.mode == 'duration'
assert config.limit_value == 600.0
run_test("Duration-based chunk limit parsing (10m = 600s)", t4)
# =============================================================================
# TEST SECTION 5: Codec Handling
# =============================================================================
def test_codec_handling():
"""Test codec handling with connector specifications."""
print("\n=== Testing Codec Handling ===")
from src.services.transcription.base import ConnectorSpecifications
def reload_audio_module():
"""Properly reload audio_conversion module with fresh env vars."""
import sys
# Remove relevant modules from cache to force fresh import
# app_config reads AUDIO_UNSUPPORTED_CODECS at import time
for mod_name in list(sys.modules.keys()):
if mod_name.startswith('src.utils') or mod_name.startswith('src.config'):
del sys.modules[mod_name]
from src.utils import audio_conversion
return audio_conversion
def t1():
clear_env()
mod = reload_audio_module()
codecs = mod.get_supported_codecs()
assert 'mp3' in codecs
assert 'flac' in codecs
run_test("Default supported codecs include common formats", t1)
def t2():
clear_env()
mod = reload_audio_module()
specs = ConnectorSpecifications(unsupported_codecs=frozenset({'opus', 'vorbis'}))
codecs = mod.get_supported_codecs(connector_specs=specs)
assert 'opus' not in codecs
assert 'vorbis' not in codecs
assert 'mp3' in codecs
run_test("Connector unsupported_codecs removes from defaults", t2)
def t3():
clear_env()
os.environ['AUDIO_UNSUPPORTED_CODECS'] = 'aac,opus'
mod = reload_audio_module()
codecs = mod.get_supported_codecs()
assert 'aac' not in codecs, f"aac should not be in {codecs}"
assert 'opus' not in codecs, f"opus should not be in {codecs}"
run_test("AUDIO_UNSUPPORTED_CODECS env var still works", t3)
def t4():
clear_env()
os.environ['AUDIO_UNSUPPORTED_CODECS'] = 'aac'
mod = reload_audio_module()
specs = ConnectorSpecifications(unsupported_codecs=frozenset({'opus'}))
codecs = mod.get_supported_codecs(connector_specs=specs)
assert 'aac' not in codecs, f"aac should not be in {codecs}"
assert 'opus' not in codecs, f"opus should not be in {codecs}"
assert 'mp3' in codecs
run_test("Both connector specs and ENV var work together", t4)
# =============================================================================
# TEST SECTION 6: Connector Capabilities
# =============================================================================
def test_connector_capabilities():
"""Test connector capabilities are exposed correctly."""
print("\n=== Testing Connector Capabilities ===")
from src.services.transcription.connectors.asr_endpoint import ASREndpointConnector
from src.services.transcription.connectors.openai_transcribe import OpenAITranscribeConnector
from src.services.transcription.base import TranscriptionCapability
def t1():
connector = ASREndpointConnector({'base_url': 'http://test:9000'})
assert connector.supports_diarization is True
run_test("ASR connector supports_diarization property", t1)
def t2():
connector = ASREndpointConnector({'base_url': 'http://test:9000'})
assert connector.supports_speaker_count_control is True
run_test("ASR connector supports_speaker_count_control property", t2)
def t3():
connector = OpenAITranscribeConnector({'api_key': 'test-key', 'model': 'gpt-4o-transcribe-diarize'})
assert connector.supports_diarization is True
assert connector.supports_speaker_count_control is False
run_test("OpenAI Transcribe supports diarization but not speaker_count_control", t3)
def t4():
connector = ASREndpointConnector({'base_url': 'http://test:9000'})
assert connector.supports(TranscriptionCapability.DIARIZATION) is True
assert connector.supports(TranscriptionCapability.STREAMING) is False
run_test("supports() method works correctly", t4)
# =============================================================================
# TEST SECTION 7: Registry Operations
# =============================================================================
def test_registry_operations():
"""Test registry listing and connector info."""
print("\n=== Testing Registry Operations ===")
from src.services.transcription.registry import get_registry
def t1():
clear_env()
reset_registry()
registry = get_registry()
connectors = registry.list_connectors()
names = [c['name'] for c in connectors]
assert 'openai_whisper' in names
assert 'openai_transcribe' in names
assert 'asr_endpoint' in names
run_test("Registry lists all built-in connectors", t1)
def t2():
clear_env()
reset_registry()
registry = get_registry()
connectors = registry.list_connectors()
asr = next(c for c in connectors if c['name'] == 'asr_endpoint')
assert 'DIARIZATION' in asr['capabilities']
assert 'SPEAKER_COUNT_CONTROL' in asr['capabilities']
run_test("Connector info includes capabilities", t2)
def t3():
clear_env()
reset_registry()
os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
os.environ['TRANSCRIPTION_MODEL'] = 'whisper-1'
registry = get_registry()
registry.initialize_from_env()
assert registry.get_active_connector_name() == 'openai_whisper'
os.environ['TRANSCRIPTION_MODEL'] = 'gpt-4o-transcribe-diarize'
registry.reinitialize()
assert registry.get_active_connector_name() == 'openai_transcribe'
run_test("reinitialize() resets the active connector", t3)
# =============================================================================
# TEST SECTION 8: Edge Cases
# =============================================================================
def test_edge_cases():
"""Test edge cases and error handling."""
print("\n=== Testing Edge Cases ===")
from src.services.transcription.registry import get_registry
from src.services.transcription.exceptions import ConfigurationError
from src.services.transcription.base import TranscriptionResponse, TranscriptionSegment
def t1():
# Empty segments list returns the text (empty string), not "[]"
response = TranscriptionResponse(text="", segments=[], provider="test")
assert response.to_storage_format() == ""
assert response.has_diarization() is False
run_test("Empty transcription response handling", t1)
def t2():
segments = [TranscriptionSegment(text="Hello", speaker=None)]
response = TranscriptionResponse(text="Hello", segments=segments)
storage = response.to_storage_format()
data = json.loads(storage)
assert data[0]['speaker'] == 'Unknown Speaker'
run_test("Transcription with unknown speaker handling", t2)
def t3():
clear_env()
reset_registry()
os.environ['TRANSCRIPTION_CONNECTOR'] = 'nonexistent_connector'
registry = get_registry()
try:
registry.initialize_from_env()
assert False, "Should have raised ConfigurationError"
except ConfigurationError as e:
assert 'Unknown connector' in str(e)
run_test("Invalid connector name raises ConfigurationError", t3)
def t4():
from src.services.transcription.connectors.asr_endpoint import ASREndpointConnector
try:
ASREndpointConnector({})
assert False, "Should have raised ConfigurationError"
except ConfigurationError as e:
assert 'base_url' in str(e)
run_test("ASR connector validates base_url is required", t4)
def t5():
clear_env()
reset_registry()
os.environ['ASR_BASE_URL'] = 'http://whisperx:9000 # This is a comment'
registry = get_registry()
connector = registry.initialize_from_env()
assert connector.base_url == 'http://whisperx:9000'
run_test("ASR_BASE_URL with trailing comment is handled", t5)
# =============================================================================
# Main
# =============================================================================
def main():
"""Run all tests."""
global PASSED, FAILED, ERRORS
print("=" * 60)
print("Transcription Connector Architecture Tests")
print("=" * 60)
test_base_classes()
test_auto_detection()
test_connector_specifications()
test_chunking_logic()
test_codec_handling()
test_connector_capabilities()
test_registry_operations()
test_edge_cases()
print("\n" + "=" * 60)
print(f"RESULTS: {PASSED} passed, {FAILED} failed")
print("=" * 60)
if ERRORS:
print("\nFailed tests:")
for name, error in ERRORS:
print(f" - {name}: {error}")
clear_env()
return 0 if FAILED == 0 else 1
if __name__ == '__main__':
sys.exit(main())

View File

@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Test script for ffprobe codec detection functionality.
This script tests the new codec-based detection system to ensure it correctly
identifies audio codecs, video files, and lossless formats.
"""
import os
import sys
import tempfile
import subprocess
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.utils.ffprobe import (
get_codec_info,
is_video_file,
is_audio_file,
get_audio_codec,
needs_audio_conversion,
is_lossless_audio,
get_duration,
FFProbeError
)
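# For context: these helpers are expected to wrap an ffprobe invocation roughly
# like `ffprobe -v quiet -print_format json -show_format -show_streams <file>`
# and expose the parsed result as a dict with 'audio_codec', 'video_codec',
# 'has_audio', 'has_video', 'format_name' and 'duration' keys (the shape
# consumed throughout the tests below); the exact implementation may differ.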
def create_test_audio_file(codec, output_path, duration=1.0):
"""Create a test audio file with specific codec."""
codec_map = {
'mp3': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'libmp3lame', '-b:a', '128k', output_path],
'aac': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'aac', '-b:a', '128k', output_path],
'opus': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'libopus', '-b:a', '64k', output_path],
'flac': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'flac', output_path],
'pcm_s16le': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'pcm_s16le', '-ar', '44100', output_path],
'vorbis': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'libvorbis', '-b:a', '128k', output_path],
}
if codec not in codec_map:
raise ValueError(f"Unknown codec: {codec}")
subprocess.run(codec_map[codec], check=True, capture_output=True)
def create_test_video_file(output_path, duration=1.0):
"""Create a test video file with audio."""
subprocess.run([
'ffmpeg', '-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=320x240:rate=1',
'-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
'-acodec', 'aac', '-vcodec', 'libx264', '-pix_fmt', 'yuv420p',
output_path
], check=True, capture_output=True)
def test_codec_detection():
"""Test basic codec detection."""
print("\n=== Testing Codec Detection ===\n")
with tempfile.TemporaryDirectory() as tmpdir:
test_files = {
'mp3': 'test.mp3',
'aac': 'test.m4a',
'opus': 'test.opus',
'flac': 'test.flac',
'pcm_s16le': 'test.wav',
'vorbis': 'test.ogg',
}
for codec, filename in test_files.items():
filepath = os.path.join(tmpdir, filename)
try:
print(f"Creating test file: {filename} with codec {codec}...")
create_test_audio_file(codec, filepath)
print(f" Probing {filename}...")
codec_info = get_codec_info(filepath)
detected_codec = codec_info['audio_codec']
print(f" ✓ Detected codec: {detected_codec}")
print(f" Has audio: {codec_info['has_audio']}")
print(f" Has video: {codec_info['has_video']}")
print(f" Format: {codec_info['format_name']}")
print(f" Duration: {codec_info['duration']:.2f}s" if codec_info['duration'] else " Duration: N/A")
if detected_codec != codec:
print(f" ⚠️ Warning: Expected {codec}, got {detected_codec}")
print()
except Exception as e:
print(f" ✗ Failed to test {codec}: {e}\n")
def test_video_detection():
"""Test video file detection."""
print("\n=== Testing Video Detection ===\n")
with tempfile.TemporaryDirectory() as tmpdir:
video_path = os.path.join(tmpdir, 'test_video.mp4')
audio_path = os.path.join(tmpdir, 'test_audio.mp3')
try:
print("Creating test video file...")
create_test_video_file(video_path)
print("Creating test audio file...")
create_test_audio_file('mp3', audio_path)
print(f"\nProbing video file...")
codec_info = get_codec_info(video_path)
print(f" Audio codec: {codec_info['audio_codec']}")
print(f" Video codec: {codec_info['video_codec']}")
print(f" Has audio: {codec_info['has_audio']}")
print(f" Has video: {codec_info['has_video']}")
is_video = is_video_file(video_path)
print(f" is_video_file(): {is_video}")
if not is_video:
print(" ✗ Video file not detected as video!")
else:
print(" ✓ Video file correctly detected")
print(f"\nProbing audio file...")
codec_info = get_codec_info(audio_path)
print(f" Audio codec: {codec_info['audio_codec']}")
print(f" Video codec: {codec_info['video_codec']}")
print(f" Has audio: {codec_info['has_audio']}")
print(f" Has video: {codec_info['has_video']}")
is_video = is_video_file(audio_path)
print(f" is_video_file(): {is_video}")
if is_video:
print(" ✗ Audio file incorrectly detected as video!")
else:
print(" ✓ Audio file correctly identified as audio-only")
print()
except Exception as e:
print(f"✗ Failed to test video detection: {e}\n")
def test_lossless_detection():
"""Test lossless audio detection."""
print("\n=== Testing Lossless Detection ===\n")
with tempfile.TemporaryDirectory() as tmpdir:
test_cases = {
'pcm_s16le': ('test.wav', True),
'flac': ('test.flac', True),
'mp3': ('test.mp3', False),
'aac': ('test.m4a', False),
'opus': ('test.opus', False),
}
for codec, (filename, expected_lossless) in test_cases.items():
filepath = os.path.join(tmpdir, filename)
try:
print(f"Creating {filename} with codec {codec}...")
create_test_audio_file(codec, filepath)
is_lossless = is_lossless_audio(filepath)
status = "" if is_lossless == expected_lossless else ""
print(f" {status} {codec}: is_lossless={is_lossless} (expected {expected_lossless})")
except Exception as e:
print(f" ✗ Failed to test {codec}: {e}")
print()
def test_conversion_check():
"""Test conversion requirement detection."""
print("\n=== Testing Conversion Check ===\n")
with tempfile.TemporaryDirectory() as tmpdir:
# Supported codecs for direct transcription
supported_codecs = ['pcm_s16le', 'mp3', 'flac', 'opus', 'aac']
test_cases = {
'mp3': ('test.mp3', False), # Supported, no conversion needed
'aac': ('test.m4a', False), # Supported, no conversion needed
'opus': ('test.opus', False), # Supported, no conversion needed
'vorbis': ('test.ogg', True), # Not in supported list, needs conversion
}
for codec, (filename, should_convert) in test_cases.items():
filepath = os.path.join(tmpdir, filename)
try:
print(f"Creating {filename} with codec {codec}...")
create_test_audio_file(codec, filepath)
needs_conversion, detected_codec = needs_audio_conversion(filepath, supported_codecs)
status = "" if needs_conversion == should_convert else ""
print(f" {status} {codec}: needs_conversion={needs_conversion} (expected {should_convert})")
print(f" Detected codec: {detected_codec}")
except Exception as e:
print(f" ✗ Failed to test {codec}: {e}")
print()
def test_misnamed_file():
"""Test detection of files with wrong extensions."""
print("\n=== Testing Misnamed File Detection ===\n")
with tempfile.TemporaryDirectory() as tmpdir:
# Create an MP3 file but name it .wav
wrong_name_path = os.path.join(tmpdir, 'actually_mp3.wav')
try:
print("Creating MP3 file with .wav extension...")
create_test_audio_file('mp3', wrong_name_path)
codec_info = get_codec_info(wrong_name_path)
detected_codec = codec_info['audio_codec']
print(f" Filename: actually_mp3.wav")
print(f" Detected codec: {detected_codec}")
if detected_codec == 'mp3':
print(" ✓ Correctly detected MP3 codec despite .wav extension")
else:
print(f" ✗ Incorrectly detected as {detected_codec}")
# Create a FLAC file but name it .mp3
wrong_name_path2 = os.path.join(tmpdir, 'actually_flac.mp3')
print("\nCreating FLAC file with .mp3 extension...")
create_test_audio_file('flac', wrong_name_path2)
codec_info = get_codec_info(wrong_name_path2)
detected_codec = codec_info['audio_codec']
print(f" Filename: actually_flac.mp3")
print(f" Detected codec: {detected_codec}")
if detected_codec == 'flac':
print(" ✓ Correctly detected FLAC codec despite .mp3 extension")
else:
print(f" ✗ Incorrectly detected as {detected_codec}")
print()
except Exception as e:
print(f"✗ Failed to test misnamed files: {e}\n")
def test_duration():
"""Test duration extraction."""
print("\n=== Testing Duration Extraction ===\n")
with tempfile.TemporaryDirectory() as tmpdir:
durations = [1.0, 2.5, 5.0]
for expected_duration in durations:
filepath = os.path.join(tmpdir, f'test_{expected_duration}s.mp3')
try:
print(f"Creating {expected_duration}s audio file...")
create_test_audio_file('mp3', filepath, duration=expected_duration)
detected_duration = get_duration(filepath)
# Allow 0.1s tolerance for encoding variations
if detected_duration and abs(detected_duration - expected_duration) < 0.1:
print(f" ✓ Duration: {detected_duration:.2f}s (expected {expected_duration}s)")
else:
shown = f"{detected_duration:.2f}s" if detected_duration is not None else "N/A"
print(f" ✗ Duration: {shown} (expected {expected_duration}s)")
except Exception as e:
print(f" ✗ Failed to test duration: {e}")
print()
def main():
"""Run all tests."""
print("=" * 60)
print("FFProbe Codec Detection Test Suite")
print("=" * 60)
# Check if ffmpeg/ffprobe are available
try:
subprocess.run(['ffprobe', '-version'], capture_output=True, check=True)
subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
except (FileNotFoundError, subprocess.CalledProcessError):
print("\n✗ Error: ffmpeg/ffprobe not found. Please install ffmpeg to run tests.\n")
return 1
try:
test_codec_detection()
test_video_detection()
test_lossless_detection()
test_conversion_check()
test_misnamed_file()
test_duration()
print("=" * 60)
print("All tests completed!")
print("=" * 60)
print()
return 0
except Exception as e:
print(f"\n✗ Test suite failed with error: {e}\n")
import traceback
traceback.print_exc()
return 1
if __name__ == '__main__':
sys.exit(main())

241
tests/test_hotwords.sh Executable file
View File

@@ -0,0 +1,241 @@
#!/bin/bash
# test_hotwords.sh - Test hotwords and initial_prompt features
#
# Usage:
# ./tests/test_hotwords.sh <asr_url> <audio_file>
# ./tests/test_hotwords.sh http://localhost:9000 temp/Recording\ 4.flac
#
# The audio should contain domain-specific words that Whisper tends to
# misspell (brand names, acronyms, unusual proper nouns). The script runs
# four transcriptions and compares the results.
set -euo pipefail
ASR_URL="${1:-http://localhost:9000}"
AUDIO_FILE="${2:-temp/Recording 4.flac}"
OUTPUT_DIR="temp/hotwords_test_results"
# Colors
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
RED='\033[0;31m'
NC='\033[0m' # No Color
# Configurable test values - adjust these for your audio
HOTWORDS="Speakr,CTranslate2,PyAnnote,WhisperX"
INITIAL_PROMPT="This is a meeting about AI-powered audio transcription tools including Speakr, CTranslate2, and PyAnnote."
echo -e "${CYAN}============================================${NC}"
echo -e "${CYAN} Hotwords & Initial Prompt Test Suite${NC}"
echo -e "${CYAN}============================================${NC}"
echo ""
echo -e "ASR URL: ${YELLOW}${ASR_URL}${NC}"
echo -e "Audio file: ${YELLOW}${AUDIO_FILE}${NC}"
echo -e "Hotwords: ${YELLOW}${HOTWORDS}${NC}"
echo -e "Initial prompt: ${YELLOW}${INITIAL_PROMPT}${NC}"
echo ""
# Verify audio file exists
if [ ! -f "$AUDIO_FILE" ]; then
echo -e "${RED}ERROR: Audio file not found: ${AUDIO_FILE}${NC}"
exit 1
fi
# Verify ASR endpoint is reachable
echo -n "Checking ASR endpoint... "
if curl -sf "${ASR_URL}/" > /dev/null 2>&1 || curl -sf "${ASR_URL}/health" > /dev/null 2>&1; then
echo -e "${GREEN}OK${NC}"
else
echo -e "${RED}FAILED${NC}"
echo "Cannot reach ASR endpoint at ${ASR_URL}"
exit 1
fi
# Create output directory
mkdir -p "$OUTPUT_DIR"
# ==============================================================
# Test 1: Baseline (no hints)
# ==============================================================
echo ""
echo -e "${CYAN}--- Test 1: Baseline (no hotwords, no initial_prompt) ---${NC}"
echo -n "Transcribing... "
BASELINE_FILE="$OUTPUT_DIR/1_baseline.json"
curl -sS -X POST "${ASR_URL}/asr?output=json&task=transcribe" \
-F "audio_file=@${AUDIO_FILE}" \
-o "$BASELINE_FILE"
BASELINE_TEXT=$(python3 -c "
import json
d=json.load(open('$BASELINE_FILE'))
t=d.get('text','')
if isinstance(t, list):
t=' '.join(seg.get('text','') for seg in t)
print(t[:500])
" 2>/dev/null || echo "PARSE_ERROR")
echo -e "${GREEN}Done${NC}"
echo -e "Preview: ${BASELINE_TEXT:0:200}..."
echo ""
# ==============================================================
# Test 2: With hotwords only
# ==============================================================
echo -e "${CYAN}--- Test 2: With hotwords ---${NC}"
echo -n "Transcribing... "
HOTWORDS_FILE="$OUTPUT_DIR/2_with_hotwords.json"
curl -sS -X POST "${ASR_URL}/asr?output=json&task=transcribe&hotwords=${HOTWORDS}" \
-F "audio_file=@${AUDIO_FILE}" \
-o "$HOTWORDS_FILE"
HOTWORDS_TEXT=$(python3 -c "
import json
d=json.load(open('$HOTWORDS_FILE'))
t=d.get('text','')
if isinstance(t, list):
t=' '.join(seg.get('text','') for seg in t)
print(t[:500])
" 2>/dev/null || echo "PARSE_ERROR")
echo -e "${GREEN}Done${NC}"
echo -e "Preview: ${HOTWORDS_TEXT:0:200}..."
echo ""
# ==============================================================
# Test 3: With hotwords + initial_prompt
# ==============================================================
echo -e "${CYAN}--- Test 3: With hotwords + initial_prompt ---${NC}"
echo -n "Transcribing... "
BOTH_FILE="$OUTPUT_DIR/3_with_both.json"
ENCODED_PROMPT=$(python3 -c "import urllib.parse; print(urllib.parse.quote('$INITIAL_PROMPT'))")
curl -sS -X POST "${ASR_URL}/asr?output=json&task=transcribe&hotwords=${HOTWORDS}&initial_prompt=${ENCODED_PROMPT}" \
-F "audio_file=@${AUDIO_FILE}" \
-o "$BOTH_FILE"
BOTH_TEXT=$(python3 -c "
import json
d=json.load(open('$BOTH_FILE'))
t=d.get('text','')
if isinstance(t, list):
t=' '.join(seg.get('text','') for seg in t)
print(t[:500])
" 2>/dev/null || echo "PARSE_ERROR")
echo -e "${GREEN}Done${NC}"
echo -e "Preview: ${BOTH_TEXT:0:200}..."
echo ""
# ==============================================================
# Test 4: With initial_prompt only
# ==============================================================
echo -e "${CYAN}--- Test 4: With initial_prompt only ---${NC}"
echo -n "Transcribing... "
PROMPT_FILE="$OUTPUT_DIR/4_with_initial_prompt.json"
curl -sS -X POST "${ASR_URL}/asr?output=json&task=transcribe&initial_prompt=${ENCODED_PROMPT}" \
-F "audio_file=@${AUDIO_FILE}" \
-o "$PROMPT_FILE"
PROMPT_TEXT=$(python3 -c "
import json
d=json.load(open('$PROMPT_FILE'))
t=d.get('text','')
if isinstance(t, list):
t=' '.join(seg.get('text','') for seg in t)
print(t[:500])
" 2>/dev/null || echo "PARSE_ERROR")
echo -e "${GREEN}Done${NC}"
echo -e "Preview: ${PROMPT_TEXT:0:200}..."
echo ""
# ==============================================================
# Comparison
# ==============================================================
echo -e "${CYAN}============================================${NC}"
echo -e "${CYAN} Comparison Results${NC}"
echo -e "${CYAN}============================================${NC}"
echo ""
# Check if hotwords appear in outputs
python3 << 'PYEOF'
import json
import os
def extract_text(data):
"""Extract full text from ASR response, handling both string and segment list formats."""
text = data.get("text", "")
if isinstance(text, list):
return " ".join(seg.get("text", "") for seg in text)
return text
hotwords = ["Speakr", "CTranslate2", "PyAnnote", "WhisperX"]
output_dir = os.environ.get("OUTPUT_DIR", "temp/hotwords_test_results")
test_files = {
"1. Baseline": f"{output_dir}/1_baseline.json",
"2. Hotwords only": f"{output_dir}/2_with_hotwords.json",
"3. Hotwords + prompt": f"{output_dir}/3_with_both.json",
"4. Initial prompt only": f"{output_dir}/4_with_initial_prompt.json",
}
print(f"{'Test':<25} | {'Hotword Matches':<20} | {'Words Found'}")
print("-" * 75)
for label, filepath in test_files.items():
try:
with open(filepath) as f:
data = json.load(f)
text = extract_text(data)
found = []
for hw in hotwords:
if hw.lower() in text.lower():
found.append(hw)
match_str = f"{len(found)}/{len(hotwords)}"
found_str = ", ".join(found) if found else "(none)"
print(f"{label:<25} | {match_str:<20} | {found_str}")
except Exception as e:
print(f"{label:<25} | ERROR: {e}")
print()
print("Full outputs saved to: " + output_dir)
PYEOF
echo ""
echo -e "${CYAN}============================================${NC}"
echo -e "${CYAN} Precedence Test via Speakr API${NC}"
echo -e "${CYAN}============================================${NC}"
echo ""
echo -e "${YELLOW}To test the full precedence chain (user → folder → tag → upload form),${NC}"
echo -e "${YELLOW}use the Speakr web API with authentication:${NC}"
echo ""
echo -e "1. Set user-level defaults in Account Settings → Prompt Options"
echo -e "2. Create a tag with different hotwords/initial_prompt"
echo -e "3. Create a folder with different hotwords/initial_prompt"
echo -e "4. Upload via API and check server logs for resolved values:"
echo ""
cat << 'EXAMPLE'
# Upload with user defaults only (no tag, no folder)
curl -X POST "https://your-speakr/upload" \
-H "Authorization: Bearer YOUR_TOKEN" \
-F "file=@test.flac"
# → Should use user defaults
# Upload with a tag that has hotwords set
curl -X POST "https://your-speakr/upload" \
-H "Authorization: Bearer YOUR_TOKEN" \
-F "file=@test.flac" \
-F "tags=TAG_ID_WITH_HOTWORDS"
# → Should use tag defaults (overrides user)
# Upload with explicit form values (highest priority)
curl -X POST "https://your-speakr/upload" \
-H "Authorization: Bearer YOUR_TOKEN" \
-F "file=@test.flac" \
-F "tags=TAG_ID_WITH_HOTWORDS" \
-F "hotwords=FormOverride1,FormOverride2" \
-F "initial_prompt=Form level prompt"
# → Should use form values (overrides tag and user)
EXAMPLE
echo ""
echo -e "${GREEN}Test complete!${NC}"

187
tests/test_inquire_mode.py Normal file
View File

@@ -0,0 +1,187 @@
#!/usr/bin/env python3
"""
Test script for Inquire Mode functionality
"""
import os
import sys
# Add the parent directory to the path to import app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.app import app, db, User, Recording, TranscriptChunk, InquireSession, Tag
def test_database_models():
"""Test that the new database models work correctly."""
with app.app_context():
print("🔍 Testing Inquire Mode Database Models...")
# Test that tables exist
from sqlalchemy import inspect
inspector = inspect(db.engine)
tables = inspector.get_table_names()
required_tables = ['transcript_chunk', 'inquire_session']
for table in required_tables:
if table in tables:
print(f"✅ Table '{table}' exists")
else:
print(f"❌ Table '{table}' missing")
return False
# Test creating sample data
try:
# Create a test user (or get existing one)
user = User.query.first()
if not user:
print("❌ No users found. Please create a user first.")
return False
print(f"📝 Using test user: {user.username}")
# Create a test recording if none exist
recording = Recording.query.filter_by(user_id=user.id).first()
if not recording:
print("❌ No recordings found. Please create a recording first.")
return False
print(f"🎵 Using test recording: {recording.title}")
# Test TranscriptChunk creation
chunk = TranscriptChunk(
recording_id=recording.id,
user_id=user.id,
chunk_index=0,
content="This is a test transcription chunk.",
start_time=0.0,
end_time=5.0,
speaker_name="Test Speaker"
)
db.session.add(chunk)
# Test InquireSession creation
session = InquireSession(
user_id=user.id,
session_name="Test Session",
filter_tags='[]',
filter_speakers='["Test Speaker"]'
)
db.session.add(session)
db.session.commit()
print("✅ Successfully created test TranscriptChunk and InquireSession")
# Clean up test data
db.session.delete(chunk)
db.session.delete(session)
db.session.commit()
print("✅ Test data cleaned up")
except Exception as e:
print(f"❌ Error testing models: {e}")
return False
return True
def test_chunking_functions():
"""Test the chunking and embedding functions."""
with app.app_context():
print("🔧 Testing Chunking Functions...")
try:
from src.app import chunk_transcription, generate_embeddings, serialize_embedding, deserialize_embedding
# Test chunking
test_text = "This is a test sentence. This is another sentence for testing. And here's a third sentence to make sure chunking works properly with longer text that should be split into multiple chunks."
chunks = chunk_transcription(test_text, max_chunk_length=100, overlap=20)
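# (Assumed semantics: chunks of at most ~100 characters, with ~20 characters
# of overlap carried between consecutive chunks.)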
if len(chunks) > 1:
print(f"✅ Chunking works: {len(chunks)} chunks created")
else:
print("✅ Text too short for chunking (expected behavior)")
# Test embeddings (will only work if sentence-transformers is installed)
try:
embeddings = generate_embeddings(["test sentence", "another test"])
if len(embeddings) == 2:
print("✅ Embedding generation works")
# Test serialization
if embeddings[0] is not None:
serialized = serialize_embedding(embeddings[0])
deserialized = deserialize_embedding(serialized)
if deserialized is not None and len(deserialized) > 0:
print("✅ Embedding serialization/deserialization works")
else:
print("❌ Embedding serialization/deserialization failed")
else:
print("❌ Embedding generation returned wrong number of embeddings")
except Exception as e:
print(f"⚠️ Embedding test skipped (sentence-transformers may not be installed): {e}")
except Exception as e:
print(f"❌ Error testing chunking functions: {e}")
return False
return True
def test_api_imports():
"""Test that all API endpoints can be imported."""
print("🔌 Testing API Endpoint Imports...")
try:
from src.app import (
get_inquire_sessions, create_inquire_session, inquire_search,
inquire_chat, get_available_filters, process_recording_chunks_endpoint
)
print("✅ All inquire mode API endpoints imported successfully")
return True
except ImportError as e:
print(f"❌ Failed to import API endpoints: {e}")
return False
def main():
"""Run all tests."""
print("🚀 Starting Inquire Mode Tests...\n")
tests = [
("Database Models", test_database_models),
("Chunking Functions", test_chunking_functions),
("API Imports", test_api_imports)
]
results = []
for test_name, test_func in tests:
print(f"\n--- {test_name} ---")
try:
success = test_func()
results.append((test_name, success))
except Exception as e:
print(f"{test_name} failed with exception: {e}")
results.append((test_name, False))
print("\n" + "="*50)
print("📊 Test Results Summary:")
print("="*50)
all_passed = True
for test_name, success in results:
status = "✅ PASS" if success else "❌ FAIL"
print(f"{status} - {test_name}")
if not success:
all_passed = False
print("\n" + "="*50)
if all_passed:
print("🎉 All tests passed! Inquire Mode is ready to use.")
else:
print("⚠️ Some tests failed. Please check the output above.")
return all_passed
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,281 @@
#!/usr/bin/env python3
"""
Test script for job queue race condition fix.
This script verifies that the atomic job claiming mechanism prevents
multiple workers from claiming the same job simultaneously.
The fix uses an atomic UPDATE with WHERE clause to ensure only one
worker can claim a job, even with multiple processes/threads.
"""
import os
import sys
import threading
import time
from pathlib import Path
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
def test_atomic_job_claiming():
"""
Test that only one worker can claim a job even with concurrent attempts.
This simulates the race condition where multiple workers try to claim
the same job simultaneously.
"""
print("\n=== Testing Atomic Job Claiming ===\n")
# Import Flask app and models
from src.app import app
from src.database import db
from src.models import ProcessingJob, User, Recording
from sqlalchemy import update
with app.app_context():
# Use the first existing user for testing, or create a minimal test user
test_user = User.query.first()
if not test_user:
test_user = User(
username='test_race_condition_user',
email='test_race@example.com',
password='not_used' # Password not needed for this test
)
db.session.add(test_user)
db.session.commit()
# Create a test recording
test_recording = Recording(
user_id=test_user.id,
title='Test Race Condition Recording',
audio_path='/tmp/test_audio.mp3',
status='QUEUED'
)
db.session.add(test_recording)
db.session.commit()
# Create a test job in 'queued' status
test_job = ProcessingJob(
recording_id=test_recording.id,
user_id=test_user.id,
job_type='transcribe',
status='queued'
)
db.session.add(test_job)
db.session.commit()
job_id = test_job.id
print(f"Created test job {job_id} with status 'queued'")
# Track which threads successfully claimed the job
successful_claims = []
claim_lock = threading.Lock()
def attempt_claim(worker_id):
"""Simulate a worker attempting to claim the job."""
with app.app_context():
try:
# This is the atomic claim logic from the fix
claim_time = datetime.utcnow()
result = db.session.execute(
update(ProcessingJob)
.where(
ProcessingJob.id == job_id,
ProcessingJob.status == 'queued'
)
.values(status='processing', started_at=claim_time)
)
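# The statement above compiles to roughly:
#   UPDATE processing_job SET status='processing', started_at=:now
#   WHERE id=:job_id AND status='queued'
# (table name assumed from the model). rowcount == 1 means this worker
# won the race; rowcount == 0 means another worker claimed the job first.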
if result.rowcount == 1:
db.session.commit()
with claim_lock:
successful_claims.append(worker_id)
return f"Worker {worker_id}: Successfully claimed job"
else:
db.session.rollback()
return f"Worker {worker_id}: Job already claimed (rowcount=0)"
except Exception as e:
db.session.rollback()
return f"Worker {worker_id}: Error - {e}"
# Spawn multiple threads to claim simultaneously
num_workers = 10
print(f"\nSpawning {num_workers} workers to claim job {job_id} simultaneously...")
# Use a barrier to ensure all threads start at the same time
barrier = threading.Barrier(num_workers)
def worker_with_barrier(worker_id):
barrier.wait() # Wait for all threads to be ready
return attempt_claim(worker_id)
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(worker_with_barrier, i): i for i in range(num_workers)}
for future in as_completed(futures):
result = future.result()
print(f" {result}")
# Verify results
print(f"\n=== Results ===")
print(f"Total workers: {num_workers}")
print(f"Successful claims: {len(successful_claims)}")
print(f"Workers that claimed: {successful_claims}")
# Check final job status
db.session.expire_all()
final_job = db.session.get(ProcessingJob, job_id)
print(f"Final job status: {final_job.status}")
# Assert only one worker claimed the job (before cleanup, while the row still exists)
assert len(successful_claims) == 1, f"Expected 1 successful claim, got {len(successful_claims)}"
assert final_job.status == 'processing', f"Expected status 'processing', got {final_job.status}"
# Cleanup
db.session.delete(final_job)
db.session.delete(test_recording)
db.session.commit()
print("\n[PASS] Only one worker successfully claimed the job!")
return True
def test_multiple_jobs_fair_distribution():
"""
Test that multiple jobs are distributed fairly across workers.
"""
print("\n=== Testing Multiple Jobs Distribution ===\n")
from src.app import app
from src.database import db
from src.models import ProcessingJob, User, Recording
from sqlalchemy import update
with app.app_context():
# Use the first existing user for testing
test_user = User.query.first()
if not test_user:
test_user = User(
username='test_distribution_user',
email='test_dist@example.com',
password='not_used'
)
db.session.add(test_user)
db.session.commit()
# Create multiple test jobs
num_jobs = 5
job_ids = []
recording_ids = []
for i in range(num_jobs):
recording = Recording(
user_id=test_user.id,
title=f'Test Distribution Recording {i}',
audio_path=f'/tmp/test_audio_{i}.mp3',
status='QUEUED'
)
db.session.add(recording)
db.session.commit()
recording_ids.append(recording.id)
job = ProcessingJob(
recording_id=recording.id,
user_id=test_user.id,
job_type='transcribe',
status='queued'
)
db.session.add(job)
db.session.commit()
job_ids.append(job.id)
print(f"Created {num_jobs} test jobs: {job_ids}")
# Have workers claim jobs
claimed_jobs = []
def claim_any_job(worker_id):
with app.app_context():
# Find a queued job
candidate = ProcessingJob.query.filter(
ProcessingJob.status == 'queued',
ProcessingJob.job_type == 'transcribe'
).first()
if not candidate:
return None
# Atomic claim
result = db.session.execute(
update(ProcessingJob)
.where(
ProcessingJob.id == candidate.id,
ProcessingJob.status == 'queued'
)
.values(status='processing', started_at=datetime.utcnow())
)
if result.rowcount == 1:
db.session.commit()
return candidate.id
else:
db.session.rollback()
return None
# Each "worker" claims one job
for i in range(num_jobs + 2): # Extra attempts to ensure no double claims
job_id = claim_any_job(i)
if job_id:
claimed_jobs.append(job_id)
print(f" Worker {i} claimed job {job_id}")
else:
print(f" Worker {i} found no available jobs")
print(f"\nClaimed jobs: {claimed_jobs}")
print(f"Unique jobs claimed: {len(set(claimed_jobs))}")
# Verify no duplicates
assert len(claimed_jobs) == len(set(claimed_jobs)), "Duplicate job claims detected!"
assert len(claimed_jobs) == num_jobs, f"Expected {num_jobs} claims, got {len(claimed_jobs)}"
# Cleanup
for job_id in job_ids:
job = db.session.get(ProcessingJob, job_id)
if job:
db.session.delete(job)
for rec_id in recording_ids:
rec = db.session.get(Recording, rec_id)
if rec:
db.session.delete(rec)
db.session.commit()
print("\n[PASS] All jobs claimed exactly once!")
return True
if __name__ == '__main__':
print("=" * 60)
print("Job Queue Race Condition Tests")
print("=" * 60)
try:
test_atomic_job_claiming()
test_multiple_jobs_fair_distribution()
print("\n" + "=" * 60)
print("All tests passed!")
print("=" * 60)
except AssertionError as e:
print(f"\n[FAIL] Test failed: {e}")
sys.exit(1)
except Exception as e:
print(f"\n[ERROR] Unexpected error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

58
tests/test_json_fix.py Normal file
View File

@@ -0,0 +1,58 @@
import json
import sys
import os
# Add the parent directory to the path to import app
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.app import auto_close_json, safe_json_loads
def run_tests():
"""Runs a series of tests for the JSON fixing functions."""
test_cases_auto_close = {
"Unterminated string": ('{"title": "Test", "summary": "This is a test', '{"title": "Test", "summary": "This is a test"}'),
"Missing closing brace": ('{"title": "Test", "summary": "This is a test"}', '{"title": "Test", "summary": "This is a test"}'),
"Missing closing bracket": ('[{"item": 1}, {"item": 2}', '[{"item": 1}, {"item": 2}]'),
"Nested unterminated": ('{"data": {"items": [1, 2', '{"data": {"items": [1, 2]}}'),
"String at the end": ('{"key": "value', '{"key": "value"}'),
"Empty string": ('', ''),
"Already valid": ('{"a": 1}', '{"a": 1}'),
"Complex nested object": ('{"a": {"b": {"c": [1, 2, {"d": "e' , '{"a": {"b": {"c": [1, 2, {"d": "e"}]}}}')
}
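# Behaviour inferred from the cases above: auto_close_json appears to track
# open quotes, braces, and brackets and append the missing terminators in
# reverse nesting order; empty input is returned unchanged.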
print("--- Testing auto_close_json ---")
for name, (input_str, expected_str) in test_cases_auto_close.items():
result = auto_close_json(input_str)
print(f"Test: {name}")
print(f" Input: '{input_str}'")
print(f" Output: '{result}'")
print(f" Expected: '{expected_str}'")
assert result == expected_str, f"Failed: {name}"
print(" Result: PASSED\n")
test_cases_safe_loads = {
"Unterminated string": '{"title": "Test", "summary": "This is a test',
"Markdown with unterminated JSON": '```json\n{"title": "Test", "summary": "This is a test\n```',
"Missing closing brace": '{"title": "Test", "summary": "This is a test"}',
"Valid JSON": '{"title": "Complete", "summary": "This is a complete JSON."}',
"JSON with escaped quotes": '{"title": "Escaped", "summary": "This is a \\"test\\" with quotes."}',
"Invalid JSON": 'this is not json',
}
print("\n--- Testing safe_json_loads ---")
for name, input_str in test_cases_safe_loads.items():
result = safe_json_loads(input_str)
print(f"Test: {name}")
print(f" Input: '{input_str}'")
print(f" Output: {result}")
if name == "Invalid JSON":
assert result is None, f"Failed: {name}"
print(" Result: PASSED (Correctly returned None)\n")
else:
assert isinstance(result, dict), f"Failed: {name}"
print(" Result: PASSED\n")
if __name__ == "__main__":
run_tests()
print("All tests completed successfully!")

View File

@@ -0,0 +1,291 @@
#!/usr/bin/env python3
"""
Test suite for JSON preprocessing functionality in Speakr app.
Tests the safe_json_loads function with various malformed JSON scenarios.
"""
import sys
import os
import json
import unittest
from unittest.mock import Mock
# Add the app directory to the path so we can import from app.py
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Local mock Flask app and logger stub. Note: the functions imported from
# src.app below still bind to the real app's logger; this stub does not
# intercept their logging.
class MockApp:
def __init__(self):
self.logger = Mock()
app = MockApp()
# Import the functions we want to test
from src.app import safe_json_loads, preprocess_json_escapes, extract_json_object
class TestJSONPreprocessing(unittest.TestCase):
"""Test cases for JSON preprocessing functionality."""
def test_valid_json(self):
"""Test that valid JSON is parsed correctly."""
valid_json = '{"title": "Test Meeting", "summary": "This is a test summary"}'
result = safe_json_loads(valid_json)
expected = {"title": "Test Meeting", "summary": "This is a test summary"}
self.assertEqual(result, expected)
def test_json_with_markdown_code_blocks(self):
"""Test JSON wrapped in markdown code blocks."""
markdown_json = '''```json
{
"title": "Meeting Notes",
"summary": "Key points discussed"
}
```'''
result = safe_json_loads(markdown_json)
expected = {"title": "Meeting Notes", "summary": "Key points discussed"}
self.assertEqual(result, expected)
def test_json_with_unescaped_quotes(self):
"""Test JSON with unescaped quotes in string values."""
malformed_json = '{"title": "John said "Hello world" to everyone", "summary": "Meeting summary"}'
result = safe_json_loads(malformed_json)
expected = {"title": 'John said "Hello world" to everyone', "summary": "Meeting summary"}
self.assertEqual(result, expected)
def test_json_with_mixed_quotes(self):
"""Test JSON with mixed quote scenarios."""
malformed_json = '{"title": "Alice\'s "big idea" presentation", "summary": "She said "this will change everything""}'
result = safe_json_loads(malformed_json)
self.assertIsInstance(result, dict)
self.assertIn("title", result)
self.assertIn("summary", result)
def test_json_with_newlines_and_special_chars(self):
"""Test JSON with newlines and special characters."""
malformed_json = '''{"title": "Complex Meeting", "summary": "Discussion about:\n- Point 1\n- Point 2 with "quotes"\n- Point 3"}'''
result = safe_json_loads(malformed_json)
self.assertIsInstance(result, dict)
self.assertIn("title", result)
self.assertIn("summary", result)
def test_empty_or_invalid_input(self):
"""Test handling of empty or invalid input."""
# Empty string
result = safe_json_loads("", {"default": "value"})
self.assertEqual(result, {"default": "value"})
# None input
result = safe_json_loads(None, {"default": "value"})
self.assertEqual(result, {"default": "value"})
# Non-string input
result = safe_json_loads(123, {"default": "value"})
self.assertEqual(result, {"default": "value"})
def test_completely_malformed_json(self):
"""Test completely malformed JSON that can't be fixed."""
malformed_json = '{"title": "Test", "summary": unclosed string and missing quotes}'
result = safe_json_loads(malformed_json, {"error": "fallback"})
self.assertEqual(result, {"error": "fallback"})
def test_json_with_nested_quotes(self):
"""Test JSON with deeply nested quote scenarios."""
malformed_json = '{"title": "Meeting about "Project Alpha" and "Project Beta"", "summary": "Both projects involve "cutting-edge" technology"}'
result = safe_json_loads(malformed_json)
self.assertIsInstance(result, dict)
# Should have successfully parsed something
self.assertTrue(len(result) > 0)
def test_json_array_format(self):
"""Test JSON array format (for transcription data)."""
json_array = '[{"speaker": "John", "sentence": "Hello everyone"}, {"speaker": "Jane", "sentence": "Good morning"}]'
result = safe_json_loads(json_array)
expected = [{"speaker": "John", "sentence": "Hello everyone"}, {"speaker": "Jane", "sentence": "Good morning"}]
self.assertEqual(result, expected)
def test_preprocess_json_escapes_function(self):
"""Test the preprocess_json_escapes function directly."""
input_json = '{"title": "John said "Hello" to Mary", "summary": "Simple test"}'
processed = preprocess_json_escapes(input_json)
# Should be valid JSON after preprocessing
result = json.loads(processed)
self.assertIsInstance(result, dict)
self.assertIn("title", result)
self.assertIn("summary", result)
def test_extract_json_object_function(self):
"""Test the extract_json_object function directly."""
# Test with extra text around JSON object
text_with_json = 'Here is some text {"title": "Test", "summary": "Content"} and more text'
extracted = extract_json_object(text_with_json)
result = json.loads(extracted)
expected = {"title": "Test", "summary": "Content"}
self.assertEqual(result, expected)
# Test with JSON array
text_with_array = 'Some text [{"item": "one"}, {"item": "two"}] more text'
extracted = extract_json_object(text_with_array)
result = json.loads(extracted)
expected = [{"item": "one"}, {"item": "two"}]
self.assertEqual(result, expected)
def test_real_world_llm_response_scenarios(self):
"""Test real-world scenarios that might come from LLM responses."""
# Scenario 1: LLM response with explanation text
llm_response1 = '''Here's the JSON response you requested:
```json
{
"title": "Q3 Planning Meeting",
"summary": "We discussed the "new initiative" and John's "breakthrough idea" for next quarter."
}
```
This should help with your transcription needs.'''
result1 = safe_json_loads(llm_response1)
self.assertIsInstance(result1, dict)
self.assertIn("title", result1)
self.assertIn("summary", result1)
# Scenario 2: LLM response with unescaped quotes and no code blocks
llm_response2 = '{"title": "Team Standup", "summary": "Alice mentioned "the deadline is tight" and Bob said "we need more resources""}'
result2 = safe_json_loads(llm_response2)
self.assertIsInstance(result2, dict)
self.assertIn("title", result2)
self.assertIn("summary", result2)
# Scenario 3: LLM response with speaker identification
llm_response3 = '''{"SPEAKER_00": "John Smith", "SPEAKER_01": "Jane "The Expert" Doe", "SPEAKER_02": "Bob"}'''
result3 = safe_json_loads(llm_response3)
self.assertIsInstance(result3, dict)
self.assertTrue(len(result3) >= 2) # Should have parsed at least some speakers
def test_fallback_strategies(self):
"""Test that different parsing strategies work as fallbacks."""
# Test ast.literal_eval fallback for simple cases
simple_dict = "{'title': 'Simple', 'summary': 'Test'}"
result = safe_json_loads(simple_dict)
expected = {"title": "Simple", "summary": "Test"}
self.assertEqual(result, expected)
# Test regex extraction fallback
messy_response = 'Some text before {"title": "Extracted", "summary": "From regex"} some text after'
result = safe_json_loads(messy_response)
expected = {"title": "Extracted", "summary": "From regex"}
self.assertEqual(result, expected)
def test_performance_with_large_content(self):
"""Test performance with larger JSON content."""
large_summary = "This is a very long summary. " * 100 # Create a long string
large_json = f'{{"title": "Large Content Test", "summary": "{large_summary}"}}'
result = safe_json_loads(large_json)
self.assertIsInstance(result, dict)
self.assertIn("title", result)
self.assertIn("summary", result)
self.assertEqual(result["title"], "Large Content Test")
def run_comprehensive_test():
"""Run a comprehensive test with various malformed JSON examples."""
print("🧪 Running comprehensive JSON preprocessing tests...\n")
test_cases = [
{
"name": "Valid JSON",
"input": '{"title": "Test", "summary": "Valid JSON"}',
"should_succeed": True
},
{
"name": "Unescaped quotes in title",
"input": '{"title": "Meeting about "Project X"", "summary": "Discussion summary"}',
"should_succeed": True
},
{
"name": "Multiple unescaped quotes",
"input": '{"title": "John said "Hello" and Mary replied "Hi there"", "summary": "Conversation log"}',
"should_succeed": True
},
{
"name": "Markdown code block",
"input": '```json\n{"title": "Wrapped", "summary": "In code block"}\n```',
"should_succeed": True
},
{
"name": "Mixed quotes and apostrophes",
"input": '{"title": "Alice\'s "big idea" presentation", "summary": "She said it\'s "revolutionary""}',
"should_succeed": True
},
{
"name": "JSON with newlines",
"input": '{"title": "Multi-line", "summary": "Line 1\\nLine 2 with \\"quotes\\"\\nLine 3"}',
"should_succeed": True
},
{
"name": "Completely malformed",
"input": '{"title": "Test", "summary": this is not valid json at all}',
"should_succeed": False
},
{
"name": "Empty string",
"input": "",
"should_succeed": False
}
]
passed = 0
failed = 0
for i, test_case in enumerate(test_cases, 1):
print(f"Test {i}: {test_case['name']}")
print(f"Input: {test_case['input'][:100]}{'...' if len(test_case['input']) > 100 else ''}")
try:
result = safe_json_loads(test_case['input'], {"error": "fallback"})
if test_case['should_succeed']:
if result != {"error": "fallback"} and isinstance(result, dict):
print("✅ PASSED - Successfully parsed JSON")
passed += 1
else:
print("❌ FAILED - Expected successful parsing but got fallback")
failed += 1
else:
if result == {"error": "fallback"}:
print("✅ PASSED - Correctly returned fallback for malformed JSON")
passed += 1
else:
print("❌ FAILED - Expected fallback but got parsed result")
failed += 1
except Exception as e:
print(f"❌ FAILED - Exception occurred: {e}")
failed += 1
print("-" * 50)
print(f"\n📊 Test Results: {passed} passed, {failed} failed")
return failed == 0
if __name__ == "__main__":
print("🚀 Starting JSON Preprocessing Tests for Speakr App\n")
# Run the comprehensive manual test
manual_success = run_comprehensive_test()
print("\n" + "="*60)
print("🔬 Running Unit Tests")
print("="*60)
# Run the unit tests
unittest.main(argv=[''], exit=False, verbosity=2)
if manual_success:
print("\n🎉 All tests completed! JSON preprocessing should handle LLM response issues gracefully.")
else:
print("\n⚠️ Some tests failed. Please review the implementation.")

View File

@@ -0,0 +1,340 @@
#!/usr/bin/env python3
"""
Standalone test for JSON preprocessing functionality.
Tests the safe_json_loads function with various malformed JSON scenarios.
"""
import json
import re
import ast
from unittest.mock import Mock
# Mock logger for testing
class MockLogger:
def warning(self, msg): print(f"WARNING: {msg}")
def info(self, msg): print(f"INFO: {msg}")
def debug(self, msg): print(f"DEBUG: {msg}")
def error(self, msg): print(f"ERROR: {msg}")
# Create mock app with logger
class MockApp:
logger = MockLogger()
app = MockApp()
def safe_json_loads(json_string, fallback_value=None):
"""
Safely parse JSON with preprocessing to handle common LLM JSON formatting issues.
Args:
json_string (str): The JSON string to parse
fallback_value: Value to return if parsing fails (default: None)
Returns:
Parsed JSON object or fallback_value if parsing fails
"""
if not json_string or not isinstance(json_string, str):
app.logger.warning(f"Invalid JSON input: {type(json_string)} - {json_string}")
return fallback_value
# Step 1: Clean the input string
cleaned_json = json_string.strip()
# Step 2: Extract JSON from markdown code blocks if present
json_match = re.search(r'```(?:json)?\s*(.*?)\s*```', cleaned_json, re.DOTALL)
if json_match:
cleaned_json = json_match.group(1).strip()
# Step 3: Try multiple parsing strategies
parsing_strategies = [
# Strategy 1: Direct parsing (for well-formed JSON)
lambda x: json.loads(x),
# Strategy 2: Fix common escape issues
lambda x: json.loads(preprocess_json_escapes(x)),
# Strategy 3: Use ast.literal_eval as fallback for simple cases
lambda x: ast.literal_eval(x) if x.startswith(('{', '[')) else None,
# Strategy 4: Extract JSON object/array using regex
lambda x: json.loads(extract_json_object(x)),
]
for i, strategy in enumerate(parsing_strategies):
try:
result = strategy(cleaned_json)
if result is not None:
if i > 0: # Log if we had to use a fallback strategy
app.logger.info(f"JSON parsed successfully using strategy {i+1}")
return result
except (json.JSONDecodeError, ValueError, SyntaxError) as e:
if i == 0: # Only log the first failure to avoid spam
app.logger.debug(f"JSON parsing strategy {i+1} failed: {e}")
continue
# All strategies failed
app.logger.error(f"All JSON parsing strategies failed for: {cleaned_json[:200]}...")
return fallback_value
def preprocess_json_escapes(json_string):
"""
Preprocess JSON string to fix common escape issues from LLM responses.
Uses a more sophisticated approach to handle nested quotes properly.
"""
if not json_string:
return json_string
result = []
i = 0
in_string = False
escape_next = False
expecting_value = False # Track if we're expecting a value (after :)
while i < len(json_string):
char = json_string[i]
if escape_next:
# This character is escaped, add it as-is
result.append(char)
escape_next = False
elif char == '\\':
# This is an escape character
result.append(char)
escape_next = True
elif char == ':' and not in_string:
# We found a colon, next string will be a value
result.append(char)
expecting_value = True
elif char == ',' and not in_string:
# We found a comma, reset expecting_value
result.append(char)
expecting_value = False
elif char == '"':
if not in_string:
# Starting a string
in_string = True
result.append(char)
else:
# We're in a string, check if this quote should be escaped
# Look ahead to see if this is the end of the string value
j = i + 1
while j < len(json_string) and json_string[j].isspace():
j += 1
# For keys (not expecting_value), only end on colon
# For values (expecting_value), end on comma, closing brace, or closing bracket
if expecting_value:
end_chars = ',}]'
else:
end_chars = ':'
if j < len(json_string) and json_string[j] in end_chars:
# This is the end of the string
in_string = False
result.append(char)
if not expecting_value:
# We just finished a key, next will be expecting value
expecting_value = True
else:
# This is an inner quote that should be escaped
result.append('\\"')
else:
result.append(char)
i += 1
return ''.join(result)
def extract_json_object(text):
"""
Extract the first complete JSON object or array from text using regex.
"""
# Look for JSON object
obj_match = re.search(r'\{.*\}', text, re.DOTALL)
if obj_match:
return obj_match.group(0)
# Look for JSON array
arr_match = re.search(r'\[.*\]', text, re.DOTALL)
if arr_match:
return arr_match.group(0)
# Return original if no JSON structure found
return text
def run_comprehensive_test():
"""Run a comprehensive test with various malformed JSON examples."""
print("🧪 Running comprehensive JSON preprocessing tests...\n")
test_cases = [
{
"name": "Valid JSON",
"input": '{"title": "Test", "summary": "Valid JSON"}',
"should_succeed": True
},
{
"name": "Unescaped quotes in title",
"input": '{"title": "Meeting about "Project X"", "summary": "Discussion summary"}',
"should_succeed": True
},
{
"name": "Multiple unescaped quotes",
"input": '{"title": "John said "Hello" and Mary replied "Hi there"", "summary": "Conversation log"}',
"should_succeed": True
},
{
"name": "Markdown code block",
"input": '```json\n{"title": "Wrapped", "summary": "In code block"}\n```',
"should_succeed": True
},
{
"name": "Mixed quotes and apostrophes",
"input": '{"title": "Alice\'s "big idea" presentation", "summary": "She said it\'s "revolutionary""}',
"should_succeed": True
},
{
"name": "JSON with newlines",
"input": '{"title": "Multi-line", "summary": "Line 1\\nLine 2 with \\"quotes\\"\\nLine 3"}',
"should_succeed": True
},
{
"name": "LLM response with explanation",
"input": '''Here's the JSON:
```json
{
"title": "Q3 Planning",
"summary": "We discussed the "new initiative" for next quarter."
}
```
Hope this helps!''',
"should_succeed": True
},
{
"name": "Speaker identification with quotes",
"input": '{"SPEAKER_00": "John Smith", "SPEAKER_01": "Jane "The Expert" Doe", "SPEAKER_02": "Bob"}',
"should_succeed": True
},
{
"name": "Completely malformed",
"input": '{"title": "Test", "summary": this is not valid json at all}',
"should_succeed": False
},
{
"name": "Empty string",
"input": "",
"should_succeed": False
}
]
passed = 0
failed = 0
for i, test_case in enumerate(test_cases, 1):
print(f"Test {i}: {test_case['name']}")
print(f"Input: {test_case['input'][:100]}{'...' if len(test_case['input']) > 100 else ''}")
try:
result = safe_json_loads(test_case['input'], {"error": "fallback"})
if test_case['should_succeed']:
if result != {"error": "fallback"} and isinstance(result, (dict, list)):
print("✅ PASSED - Successfully parsed JSON")
print(f" Result: {result}")
passed += 1
else:
print("❌ FAILED - Expected successful parsing but got fallback")
failed += 1
else:
if result == {"error": "fallback"}:
print("✅ PASSED - Correctly returned fallback for malformed JSON")
passed += 1
else:
print("❌ FAILED - Expected fallback but got parsed result")
print(f" Unexpected result: {result}")
failed += 1
except Exception as e:
print(f"❌ FAILED - Exception occurred: {e}")
failed += 1
print("-" * 50)
print(f"\n📊 Test Results: {passed} passed, {failed} failed")
return failed == 0
def test_preprocessing_function():
"""Test the preprocessing function directly."""
print("\n🔧 Testing preprocessing function directly...\n")
test_input = '{"title": "Meeting about "Project X"", "summary": "Discussion summary"}'
print(f"Original: {test_input}")
processed = preprocess_json_escapes(test_input)
print(f"Processed: {processed}")
try:
result = json.loads(processed)
print(f"✅ Successfully parsed: {result}")
except json.JSONDecodeError as e:
print(f"❌ Still failed: {e}")
def test_specific_scenarios():
"""Test specific real-world scenarios."""
print("\n🎯 Testing specific LLM response scenarios...\n")
# Test case from the original issue
gemini_response = '''{"title": "Meeting about "Project Phoenix" and budget allocation", "summary": "The team discussed John's "breakthrough idea" and Mary said "this will change everything" during the Q3 planning session."}'''
print("Testing Gemini-style response with unescaped quotes:")
print(f"Input: {gemini_response}")
# Test preprocessing directly
processed = preprocess_json_escapes(gemini_response)
print(f"Processed: {processed}")
result = safe_json_loads(gemini_response)
if isinstance(result, dict) and "title" in result and "summary" in result:
print("✅ SUCCESS - Parsed Gemini response correctly!")
print(f"Title: {result['title']}")
print(f"Summary: {result['summary'][:100]}...")
else:
print("❌ FAILED - Could not parse Gemini response")
print(f"Result: {result}")
print("-" * 50)
# Test speaker identification scenario
speaker_response = '''{"SPEAKER_00": "John "The Manager" Smith", "SPEAKER_01": "Alice Johnson", "SPEAKER_02": "Bob "Tech Lead" Wilson"}'''
print("Testing speaker identification with quotes in names:")
print(f"Input: {speaker_response}")
# Test preprocessing directly
processed = preprocess_json_escapes(speaker_response)
print(f"Processed: {processed}")
result = safe_json_loads(speaker_response)
if isinstance(result, dict) and len(result) >= 3:
print("✅ SUCCESS - Parsed speaker identification correctly!")
for speaker, name in result.items():
print(f" {speaker}: {name}")
else:
print("❌ FAILED - Could not parse speaker identification")
print(f"Result: {result}")
if __name__ == "__main__":
print("🚀 Starting Standalone JSON Preprocessing Tests\n")
# Test preprocessing function directly
test_preprocessing_function()
# Run the comprehensive test
success = run_comprehensive_test()
# Test specific scenarios
test_specific_scenarios()
if success:
print("\n🎉 All tests completed successfully! JSON preprocessing should handle LLM response issues gracefully.")
else:
print("\n⚠️ Some tests failed. The implementation may need refinement.")

View File

@@ -0,0 +1,251 @@
"""
Test suite to ensure database migrations are compatible with both SQLite and PostgreSQL.
These tests scan the init_db.py file for patterns that would break on PostgreSQL,
such as SQLite-only boolean defaults (0/1 instead of FALSE/TRUE) and unquoted
reserved keywords.
Run with: python tests/test_migration_compatibility.py
"""
import re
import unittest
import os
class TestMigrationCompatibility(unittest.TestCase):
"""Tests to ensure init_db.py uses cross-database compatible SQL."""
@classmethod
def setUpClass(cls):
"""Load init_db.py content once for all tests."""
# Find the project root
test_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(test_dir)
init_db_path = os.path.join(project_root, 'src', 'init_db.py')
with open(init_db_path, 'r') as f:
cls.content = f.read()
def test_no_raw_boolean_defaults_in_alter_table(self):
"""
Ensure no raw ALTER TABLE statements use SQLite-only boolean defaults.
The pattern 'BOOLEAN DEFAULT 0' or 'BOOLEAN DEFAULT 1' in raw SQL
will fail on PostgreSQL, which requires 'DEFAULT FALSE' or 'DEFAULT TRUE'.
Using add_column_if_not_exists() handles this conversion automatically.
"""
# Pattern to find raw SQL with text() that has BOOLEAN DEFAULT 0/1
# This matches: text('... BOOLEAN DEFAULT 0 ...') or text("...")
pattern = r"conn\.execute\s*\(\s*text\s*\(['\"]([^'\"]*BOOLEAN\s+DEFAULT\s+[01][^'\"]*)['\"]"
matches = re.findall(pattern, self.content, re.IGNORECASE)
# Filter out false positives - we're looking for raw ALTER TABLE statements
# not UPDATE statements or other SQL that legitimately uses 0/1
problematic = []
for match in matches:
match_upper = match.upper()
# Only flag if it's an ALTER TABLE with BOOLEAN DEFAULT 0/1
if 'ALTER TABLE' in match_upper and 'BOOLEAN' in match_upper:
if 'DEFAULT 0' in match or 'DEFAULT 1' in match:
problematic.append(match)
self.assertEqual(
len(problematic), 0,
f"Found SQLite-only boolean defaults in raw ALTER TABLE statements. "
f"Use add_column_if_not_exists() instead:\n" +
"\n".join(f" - {m[:100]}..." if len(m) > 100 else f" - {m}" for m in problematic)
)
def test_no_boolean_integer_comparisons_in_raw_sql(self):
"""
Ensure raw SQL doesn't compare boolean columns to integers (0/1).
PostgreSQL strictly separates boolean and integer types:
- 'column = 1' fails with 'operator does not exist: boolean = integer'
- 'column = TRUE' works on both SQLite (3.23+) and PostgreSQL
Known boolean columns in migrations: protect_from_deletion, email_verified,
auto_share_on_apply, share_with_group_lead, is_inbox, is_highlighted,
deletion_exempt, is_admin, can_share_publicly.
"""
boolean_columns = [
'protect_from_deletion', 'email_verified', 'auto_share_on_apply',
'share_with_group_lead', 'is_inbox', 'is_highlighted',
'deletion_exempt', 'is_admin', 'can_share_publicly',
'auto_speaker_labelling', 'auto_summarization'
]
# Rather than parsing every text() SQL call with a regex, scan for lines
# where a known boolean column is compared to 0 or 1.
problematic = []
for col in boolean_columns:
# Match: column = 0 or column = 1 (not = TRUE/FALSE)
pattern = rf"{col}\s*=\s*[01]\b"
matches = re.finditer(pattern, self.content, re.IGNORECASE)
for match in matches:
# Get surrounding context to check if it's in a text() SQL call
start = max(0, match.start() - 200)
context = self.content[start:match.end() + 50]
if 'text(' in context and 'sqlite_master' not in context:
problematic.append(f"{col}: ...{match.group()}...")
self.assertEqual(
len(problematic), 0,
f"Found boolean columns compared to integers in raw SQL. "
f"Use TRUE/FALSE instead of 1/0 for PostgreSQL compatibility:\n" +
"\n".join(f" - {p}" for p in problematic)
)
def test_reserved_keywords_quoted_in_index_creation(self):
"""
Ensure reserved keywords like 'user' are properly quoted in index creation.
Raw SQL like 'CREATE INDEX ... ON user (column)' will fail on some databases
because 'user' is a reserved keyword. It should be quoted as "user" or use
the create_index_if_not_exists() utility.
"""
reserved_keywords = ['user', 'order', 'group', 'table', 'select', 'index']
problematic = []
for keyword in reserved_keywords:
# Pattern to find unquoted reserved keyword after ON in index creation
# Matches: CREATE INDEX ... ON user ( but not ON "user" or ON `user`
pattern = rf"CREATE\s+(?:UNIQUE\s+)?INDEX[^;]*\s+ON\s+{keyword}\s*\("
matches = re.findall(pattern, self.content, re.IGNORECASE)
for match in matches:
# Skip if the keyword is already quoted
if f'"{keyword}"' in match.lower() or f'`{keyword}`' in match.lower():
continue
problematic.append((keyword, match[:80]))
self.assertEqual(
len(problematic), 0,
f"Found unquoted reserved keywords in index creation. "
f"Use create_index_if_not_exists() or quote the table name:\n" +
"\n".join(f" - '{kw}' in: {sql}..." for kw, sql in problematic)
)
def test_add_column_uses_utility(self):
"""
Ensure most ADD COLUMN operations use add_column_if_not_exists().
Direct ALTER TABLE ADD COLUMN statements should use the utility function
to ensure cross-database compatibility with boolean defaults and quoting.
"""
# Count direct ALTER TABLE ADD COLUMN in text() calls
direct_pattern = r"conn\.execute\s*\(\s*text\s*\(['\"][^'\"]*ALTER\s+TABLE[^'\"]*ADD\s+COLUMN"
direct_matches = re.findall(direct_pattern, self.content, re.IGNORECASE)
# Count uses of add_column_if_not_exists
utility_pattern = r"add_column_if_not_exists\s*\("
utility_matches = re.findall(utility_pattern, self.content)
# We expect most ADD COLUMN operations to use the utility
# Allow some direct usage for special cases (e.g., table recreation)
# but utility usage should significantly outnumber direct usage
self.assertGreater(
len(utility_matches), len(direct_matches),
f"Found {len(direct_matches)} direct ALTER TABLE ADD COLUMN statements "
f"vs {len(utility_matches)} add_column_if_not_exists() calls. "
f"Consider using the utility function for cross-database compatibility."
)
def test_incompatible_types_handled_by_utility(self):
"""
Ensure columns with PostgreSQL-incompatible types (DATETIME, BLOB) are
added through add_column_if_not_exists() which auto-converts them,
and NOT via raw ALTER TABLE statements that would bypass conversion.
PostgreSQL type differences:
- DATETIME -> TIMESTAMP
- BLOB -> BYTEA
"""
incompatible_types = ['DATETIME', 'BLOB']
# Check for raw ALTER TABLE statements using incompatible types
for sql_type in incompatible_types:
pattern = rf"conn\.execute\s*\(\s*text\s*\(['\"][^'\"]*ALTER\s+TABLE[^'\"]*\b{sql_type}\b[^'\"]*['\"]"
matches = re.findall(pattern, self.content, re.IGNORECASE)
self.assertEqual(
len(matches), 0,
f"Found raw ALTER TABLE statements using '{sql_type}' which is incompatible with PostgreSQL. "
f"Use add_column_if_not_exists() which auto-converts types:\n" +
"\n".join(f" - {m[:100]}..." if len(m) > 100 else f" - {m}" for m in matches)
)
# Columns added with these types are expected to go through
# add_column_if_not_exists(), which converts them automatically
# (DATETIME -> TIMESTAMP, BLOB -> BYTEA), so no further assertion is needed.
def test_no_double_quoted_string_defaults(self):
"""
Ensure no SQL DEFAULT values use double-quoted strings.
In SQL, double quotes denote identifiers (column/table names), not string
literals. SQLite tolerates this, but PostgreSQL will interpret DEFAULT "en"
as a reference to a column named "en" and fail with 'column "en" does not exist'.
String defaults must use single quotes: DEFAULT 'en'
"""
# Match DEFAULT followed by a double-quoted string value
pattern = r'DEFAULT\s+"[^"]*"'
lines = self.content.splitlines()
problematic = []
for i, line in enumerate(lines, 1):
if re.search(pattern, line, re.IGNORECASE):
problematic.append(f" Line {i}: {line.strip()}")
self.assertEqual(
len(problematic), 0,
f"Found double-quoted string defaults in init_db.py. "
f"PostgreSQL interprets double quotes as column identifiers, not string literals. "
f"Use single quotes instead (e.g., DEFAULT 'en' not DEFAULT \"en\"):\n" +
"\n".join(problematic)
)
def test_create_index_uses_utility_for_user_table(self):
"""
Ensure index creation on 'user' table uses create_index_if_not_exists().
The 'user' table name is a reserved keyword that requires special quoting.
Using create_index_if_not_exists() handles this automatically.
"""
# Count raw index creation on user table in text() calls
raw_pattern = r"conn\.execute\s*\(\s*text\s*\(['\"][^'\"]*CREATE\s+(?:UNIQUE\s+)?INDEX[^'\"]*ON\s+[\"'`]?user"
raw_matches = re.findall(raw_pattern, self.content, re.IGNORECASE)
# Count uses of create_index_if_not_exists for user table
utility_pattern = r"create_index_if_not_exists\s*\([^)]*['\"]user['\"]"
utility_matches = re.findall(utility_pattern, self.content, re.IGNORECASE)
# All index creation on user table should use the utility
# (excluding table recreation scenarios which have their own quoting)
if len(raw_matches) > 0:
# Check if these are in table recreation blocks (acceptable)
table_recreation_pattern = r"CREATE\s+TABLE\s+user_new"
has_table_recreation = re.search(table_recreation_pattern, self.content, re.IGNORECASE)
if not has_table_recreation or len(raw_matches) > 1:
self.fail(
f"Found {len(raw_matches)} raw CREATE INDEX statements on 'user' table. "
f"Use create_index_if_not_exists() for proper quoting of reserved keywords."
)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,165 @@
"""
Integration test for database migrations against a real database engine.
Runs initialize_database() and verifies that all tables and critical columns
are created successfully. Works with both SQLite (default, for local runs)
and PostgreSQL (when TEST_DATABASE_URI env var is set).
IMPORTANT: This test uses TEST_DATABASE_URI (not SQLALCHEMY_DATABASE_URI) to
avoid accidentally connecting to and destroying a real application database.
Usage:
# Local (SQLite in-memory, safe):
python tests/test_postgres_migrations.py
# Against PostgreSQL (CI or explicit testing):
TEST_DATABASE_URI=postgresql://user:pass@localhost:5432/testdb \
python tests/test_postgres_migrations.py
"""
import os
import sys
import unittest
# Ensure project root is on the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from flask import Flask
from src.database import db
# Importing models registers them with SQLAlchemy so create_all() builds all tables
import src.models # noqa: F401
from src.init_db import initialize_database
def create_test_app():
"""Create a minimal Flask app for testing database operations.
Uses TEST_DATABASE_URI env var (NOT SQLALCHEMY_DATABASE_URI) to prevent
accidental connection to production/dev databases.
"""
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = os.environ.get(
'TEST_DATABASE_URI', 'sqlite://' # in-memory SQLite by default
)
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['TESTING'] = True
db.init_app(app)
return app
class TestDatabaseMigrations(unittest.TestCase):
"""Test that initialize_database() runs cleanly against the configured DB engine."""
@classmethod
def setUpClass(cls):
cls.app = create_test_app()
with cls.app.app_context():
initialize_database(cls.app)
@classmethod
def tearDownClass(cls):
with cls.app.app_context():
# Use raw DROP to avoid circular FK dependency errors in SQLAlchemy's
# drop_all() (user <-> naming_template have mutual foreign keys)
from sqlalchemy import inspect, text
tables = inspect(db.engine).get_table_names()
with db.engine.connect() as conn:
if db.engine.name == 'postgresql':
for table in tables:
conn.execute(text(f'DROP TABLE IF EXISTS "{table}" CASCADE'))
else:
conn.execute(text('PRAGMA foreign_keys = OFF'))
for table in tables:
conn.execute(text(f'DROP TABLE IF EXISTS "{table}"'))
conn.execute(text('PRAGMA foreign_keys = ON'))
conn.commit()
def _get_table_names(self):
from sqlalchemy import inspect
inspector = inspect(db.engine)
return inspector.get_table_names()
def _get_column_names(self, table):
from sqlalchemy import inspect
inspector = inspect(db.engine)
return [col['name'] for col in inspector.get_columns(table)]
def test_core_tables_exist(self):
"""Verify that all core tables were created."""
with self.app.app_context():
tables = self._get_table_names()
expected_tables = [
'user', 'recording', 'transcript_chunk', 'tag',
'folder', 'share', 'internal_share', 'system_setting',
'speaker', 'processing_job', 'group', 'group_membership',
]
for table in expected_tables:
self.assertIn(table, tables, f"Missing table: {table}")
def test_user_migration_columns(self):
"""Verify columns added by migrations exist on the user table."""
with self.app.app_context():
columns = self._get_column_names('user')
expected = [
'id', 'username', 'email', 'password',
'transcription_language', 'output_language', 'ui_language',
'summary_prompt', 'extract_events', 'name', 'job_title',
'company', 'diarize', 'sso_provider', 'sso_subject',
'can_share_publicly', 'monthly_token_budget',
'monthly_transcription_budget', 'email_verified',
'auto_speaker_labelling', 'auto_summarization',
]
for col in expected:
self.assertIn(col, columns, f"Missing user column: {col}")
def test_recording_migration_columns(self):
"""Verify columns added by migrations exist on the recording table."""
with self.app.app_context():
columns = self._get_column_names('recording')
expected = [
'id', 'is_inbox', 'is_highlighted', 'mime_type',
'completed_at', 'processing_time_seconds', 'error_message',
'folder_id', 'audio_deleted_at', 'deletion_exempt',
'speaker_embeddings',
]
for col in expected:
self.assertIn(col, columns, f"Missing recording column: {col}")
def test_tag_migration_columns(self):
"""Verify columns added by migrations exist on the tag table."""
with self.app.app_context():
columns = self._get_column_names('tag')
expected = [
'id', 'protect_from_deletion', 'group_id',
'retention_days', 'auto_share_on_apply',
'naming_template_id', 'export_template_id',
]
for col in expected:
self.assertIn(col, columns, f"Missing tag column: {col}")
def test_system_settings_initialized(self):
"""Verify that default system settings were created."""
with self.app.app_context():
from src.models import SystemSetting
expected_keys = [
'transcript_length_limit', 'max_file_size_mb',
'asr_timeout_seconds', 'disable_auto_summarization',
'enable_folders',
]
for key in expected_keys:
setting = SystemSetting.query.filter_by(key=key).first()
self.assertIsNotNone(setting, f"Missing system setting: {key}")
def test_engine_type_matches_expectation(self):
"""Sanity check: confirm we're testing against the expected engine."""
with self.app.app_context():
uri = self.app.config['SQLALCHEMY_DATABASE_URI']
engine_name = db.engine.name
if uri.startswith('postgresql'):
self.assertEqual(engine_name, 'postgresql')
else:
self.assertEqual(engine_name, 'sqlite')
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,435 @@
"""
Test suite for the VIDEO_PASSTHROUGH_ASR feature.
Tests configuration, code path correctness, and interaction with VIDEO_RETENTION
across all entry points (processing pipeline, upload handler, file monitor, incognito).
Uses static analysis — no running server or real video files required.
Run with: python tests/test_video_passthrough.py
"""
import os
import re
import sys
import unittest
from pathlib import Path
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(TEST_DIR)
sys.path.insert(0, PROJECT_ROOT)
def read_file(rel_path):
with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
return f.read()
# Cache file contents once — they don't change during the run
PROCESSING = read_file('src/tasks/processing.py')
RECORDINGS = read_file('src/api/recordings.py')
FILE_MONITOR = read_file('src/file_monitor.py')
APP_CONFIG = read_file('src/config/app_config.py')
ENV_EXAMPLE = read_file('config/env.transcription.example')
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def get_function_body(source, func_name):
"""Extract the body of a top-level function from source code."""
pattern = rf'^def {func_name}\('
lines = source.split('\n')
start = None
for i, line in enumerate(lines):
if re.match(pattern, line):
start = i
break
if start is None:
return ''
# Collect until next top-level def or class or EOF
body_lines = [lines[start]]
for line in lines[start + 1:]:
if line and not line[0].isspace() and (line.startswith('def ') or line.startswith('class ')):
break
body_lines.append(line)
return '\n'.join(body_lines)
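# Example (illustrative): for source "def f():\n    return 1\ndef g():\n    pass",
# get_function_body(source, 'f') returns "def f():\n    return 1".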
def split_at_incognito(source):
"""Split processing.py into main and incognito sections."""
marker = 'def transcribe_incognito('
idx = source.find(marker)
if idx == -1:
return source, ''
return source[:idx], source[idx:]
PROCESSING_MAIN, PROCESSING_INCOGNITO = split_at_incognito(PROCESSING)
# ===========================================================================
# 1. Configuration
# ===========================================================================
class TestPassthroughConfig(unittest.TestCase):
"""VIDEO_PASSTHROUGH_ASR env var is defined and defaults to false."""
FILES_THAT_NEED_IT = [
('src/config/app_config.py', APP_CONFIG),
('src/tasks/processing.py', PROCESSING),
('src/api/recordings.py', RECORDINGS),
('src/file_monitor.py', FILE_MONITOR),
]
def test_defined_in_all_files(self):
for rel_path, content in self.FILES_THAT_NEED_IT:
with self.subTest(file=rel_path):
self.assertIn('VIDEO_PASSTHROUGH_ASR', content,
f"VIDEO_PASSTHROUGH_ASR missing from {rel_path}")
def test_default_is_false_everywhere(self):
for rel_path, content in self.FILES_THAT_NEED_IT:
match = re.search(
r"VIDEO_PASSTHROUGH_ASR\s*=\s*os\.environ\.get\('VIDEO_PASSTHROUGH_ASR',\s*'(\w+)'\)",
content
)
if match:
with self.subTest(file=rel_path):
self.assertEqual(match.group(1), 'false',
f"Default should be 'false' in {rel_path}")
def test_canonical_definition_in_app_config(self):
self.assertIn(
"VIDEO_PASSTHROUGH_ASR = os.environ.get('VIDEO_PASSTHROUGH_ASR', 'false').lower() == 'true'",
APP_CONFIG
)
def test_documented_in_env_example(self):
self.assertIn('VIDEO_PASSTHROUGH_ASR', ENV_EXAMPLE)
def test_processing_imports_from_config(self):
self.assertIn('VIDEO_PASSTHROUGH_ASR', PROCESSING)
# Should import from app_config, not read os.environ directly
self.assertIn('import', PROCESSING)
# Verify it's in an import line from app_config
import_lines = [l for l in PROCESSING.split('\n')
if 'from src.config.app_config import' in l]
found = any('VIDEO_PASSTHROUGH_ASR' in l for l in import_lines)
self.assertTrue(found, "processing.py should import VIDEO_PASSTHROUGH_ASR from app_config")
# ===========================================================================
# 2. Processing pipeline — main transcription path
# ===========================================================================
class TestProcessingMainPath(unittest.TestCase):
"""Test transcribe_with_connector() video passthrough code paths."""
def test_passthrough_branch_exists_before_retention(self):
"""VIDEO_PASSTHROUGH_ASR is checked before VIDEO_RETENTION in the is_video block."""
# Inside the `if is_video:` block, passthrough should be the first check
video_block_start = PROCESSING_MAIN.find('if is_video:')
self.assertNotEqual(video_block_start, -1)
after_video = PROCESSING_MAIN[video_block_start:]
passthrough_pos = after_video.find('if VIDEO_PASSTHROUGH_ASR:')
retention_pos = after_video.find('elif VIDEO_RETENTION:')
self.assertNotEqual(passthrough_pos, -1, "Missing VIDEO_PASSTHROUGH_ASR check in is_video block")
self.assertNotEqual(retention_pos, -1, "Missing elif VIDEO_RETENTION check")
self.assertLess(passthrough_pos, retention_pos,
"VIDEO_PASSTHROUGH_ASR should be checked before VIDEO_RETENTION")
def test_passthrough_does_not_call_extract_audio(self):
"""The passthrough branch must not call extract_audio_from_video."""
video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
# Find the passthrough branch (from `if VIDEO_PASSTHROUGH_ASR:` to `elif VIDEO_RETENTION:`)
pt_start = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
pt_end = video_block.find('elif VIDEO_RETENTION:')
passthrough_block = video_block[pt_start:pt_end]
self.assertNotIn('extract_audio_from_video', passthrough_block,
"Passthrough branch should NOT extract audio")
def test_passthrough_keeps_original_filepath(self):
"""Passthrough sets actual_filepath = filepath (the original video)."""
video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
pt_start = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
pt_end = video_block.find('elif VIDEO_RETENTION:')
passthrough_block = video_block[pt_start:pt_end]
self.assertIn('actual_filepath = filepath', passthrough_block)
def test_passthrough_with_retention_sets_recording_path(self):
"""When both passthrough and retention are on, recording.audio_path is set."""
video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
pt_start = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
pt_end = video_block.find('elif VIDEO_RETENTION:')
passthrough_block = video_block[pt_start:pt_end]
self.assertIn('if VIDEO_RETENTION:', passthrough_block,
"Passthrough branch should conditionally handle retention")
self.assertIn('recording.audio_path = filepath', passthrough_block)
self.assertIn("mimetypes.guess_type(filepath)", passthrough_block)
def test_video_passthrough_active_flag_set(self):
"""video_passthrough_active flag is computed from is_video and VIDEO_PASSTHROUGH_ASR."""
self.assertIn('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR',
PROCESSING_MAIN)
def test_conversion_skipped_when_passthrough(self):
"""convert_if_needed is inside an else block gated by video_passthrough_active."""
self.assertIn('if video_passthrough_active:', PROCESSING_MAIN)
# The conversion call should be in the else branch
flag_pos = PROCESSING_MAIN.find('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR')
after_flag = PROCESSING_MAIN[flag_pos:]
passthrough_if = after_flag.find('if video_passthrough_active:')
else_pos = after_flag.find('else:', passthrough_if)
convert_pos = after_flag.find('convert_if_needed(', else_pos)
self.assertGreater(convert_pos, else_pos,
"convert_if_needed should be in else branch after passthrough check")
def test_chunking_skipped_when_passthrough(self):
"""Chunking evaluates to False when video_passthrough_active."""
# Find the chunking decision area after the flag
flag_pos = PROCESSING_MAIN.find('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR')
after_flag = PROCESSING_MAIN[flag_pos:]
self.assertIn('if video_passthrough_active:\n should_chunk = False', after_flag)
def test_conversion_still_runs_for_non_passthrough(self):
"""convert_if_needed still runs when passthrough is off or file is audio."""
# The else branch of the passthrough check should contain convert_if_needed
self.assertIn('conversion_result = convert_if_needed(', PROCESSING_MAIN)
def test_chunking_still_evaluated_for_non_passthrough(self):
"""Chunking is still evaluated normally when passthrough is not active."""
self.assertIn('chunking_service.needs_chunking(actual_filepath, False, connector_specs)',
PROCESSING_MAIN)
# ===========================================================================
# 3. Processing pipeline — VIDEO_RETENTION paths still intact
# ===========================================================================
class TestRetentionNotBroken(unittest.TestCase):
"""Existing VIDEO_RETENTION behavior must be preserved."""
def test_retention_branch_still_extracts_audio(self):
"""elif VIDEO_RETENTION branch still calls extract_audio_from_video."""
video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
ret_start = video_block.find('elif VIDEO_RETENTION:')
# Find next else: at the same indent level
after_ret = video_block[ret_start:]
else_pos = after_ret.find('\n else:')
retention_block = after_ret[:else_pos] if else_pos != -1 else after_ret[:500]
self.assertIn('extract_audio_from_video(filepath, cleanup_original=False)',
retention_block)
def test_default_branch_still_extracts_and_deletes(self):
"""The final else branch extracts audio with default cleanup (deletes video)."""
video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
# The last else in the is_video block
self.assertIn('extract_audio_from_video(filepath)', video_block)
def test_temp_audio_cleanup_still_present(self):
"""Temp audio from retention is still cleaned up after transcription."""
self.assertIn('is_video and VIDEO_RETENTION and audio_filepath', PROCESSING_MAIN)
self.assertIn('Cleaned up temp audio from video retention', PROCESSING_MAIN)
# ===========================================================================
# 4. Incognito path
# ===========================================================================
class TestIncognitoPassthrough(unittest.TestCase):
"""Test passthrough in the incognito transcription path."""
def test_passthrough_flag_set_in_incognito(self):
"""video_passthrough_active is computed in incognito path."""
self.assertIn('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR',
PROCESSING_INCOGNITO)
def test_passthrough_skips_extraction_in_incognito(self):
"""When passthrough is on, incognito skips extract_audio_from_video."""
# The passthrough branch logs and does NOT extract
self.assertIn('[Incognito] Video passthrough: sending original video to ASR',
PROCESSING_INCOGNITO)
def test_passthrough_skips_conversion_in_incognito(self):
"""When passthrough is on, incognito skips convert_if_needed."""
self.assertIn('[Incognito] Video passthrough: skipping codec conversion',
PROCESSING_INCOGNITO)
def test_passthrough_skips_chunking_in_incognito(self):
"""When passthrough is on, incognito chunking is False."""
body = PROCESSING_INCOGNITO
self.assertIn('if video_passthrough_active:\n should_chunk = False', body)
def test_incognito_does_not_reference_video_retention(self):
"""Incognito path should NOT reference VIDEO_RETENTION (no retention in incognito)."""
self.assertNotIn('VIDEO_RETENTION', PROCESSING_INCOGNITO)
def test_incognito_still_extracts_without_passthrough(self):
"""Without passthrough, incognito still extracts audio from video."""
self.assertIn('extract_audio_from_video(filepath, cleanup_original=False)',
PROCESSING_INCOGNITO)
def test_incognito_still_converts_without_passthrough(self):
"""Without passthrough, incognito still runs convert_if_needed."""
self.assertIn('convert_if_needed(', PROCESSING_INCOGNITO)
# ===========================================================================
# 5. Upload handler (recordings.py)
# ===========================================================================
class TestUploadHandlerPassthrough(unittest.TestCase):
"""Test recordings.py upload handler respects VIDEO_PASSTHROUGH_ASR."""
def test_skip_conversion_for_passthrough_video(self):
"""Upload handler skips conversion when passthrough or retention + video."""
self.assertIn('VIDEO_RETENTION or VIDEO_PASSTHROUGH_ASR) and has_video', RECORDINGS)
def test_extension_fallback_checks_passthrough(self):
"""Extension-based video detection also fires for VIDEO_PASSTHROUGH_ASR."""
self.assertIn('VIDEO_RETENTION or VIDEO_PASSTHROUGH_ASR', RECORDINGS)
def test_convert_if_needed_still_in_else(self):
"""convert_if_needed still runs for audio files or when both flags are off."""
self.assertIn('convert_if_needed(', RECORDINGS)
def test_passthrough_log_message(self):
"""Upload handler logs which mode caused the skip."""
self.assertIn("'VIDEO_PASSTHROUGH_ASR'", RECORDINGS)
# ===========================================================================
# 6. File monitor
# ===========================================================================
class TestFileMonitorPassthrough(unittest.TestCase):
"""Test file_monitor.py respects VIDEO_PASSTHROUGH_ASR."""
def test_passthrough_defined(self):
self.assertIn('VIDEO_PASSTHROUGH_ASR', FILE_MONITOR)
def test_skip_conversion_for_passthrough_or_retention(self):
"""File monitor skips conversion when passthrough or retention + video."""
self.assertIn('VIDEO_PASSTHROUGH_ASR or VIDEO_RETENTION) and has_video', FILE_MONITOR)
def test_convert_if_needed_in_else_branch(self):
"""convert_if_needed is in the else branch, not inside the skip block."""
lines = FILE_MONITOR.split('\n')
in_skip_block = False
found_else = False
for i, line in enumerate(lines):
if 'VIDEO_PASSTHROUGH_ASR or VIDEO_RETENTION) and has_video' in line:
in_skip_block = True
elif in_skip_block and line.strip().startswith('else:'):
in_skip_block = False
found_else = True
elif in_skip_block and 'convert_if_needed' in line:
self.fail(f"convert_if_needed inside skip block at line {i + 1}")
self.assertTrue(found_else, "Should have else branch after passthrough/retention skip")
def test_log_distinguishes_passthrough_from_retention(self):
"""Log message indicates whether passthrough or retention caused the skip."""
self.assertIn("'passthrough'", FILE_MONITOR)
self.assertIn("'retention'", FILE_MONITOR)
# ===========================================================================
# 7. Audio files unaffected by passthrough
# ===========================================================================
class TestAudioUnaffected(unittest.TestCase):
"""VIDEO_PASSTHROUGH_ASR must only affect video files, never audio."""
def test_passthrough_flag_gated_on_is_video(self):
"""video_passthrough_active is always `is_video and VIDEO_PASSTHROUGH_ASR`."""
# Main path
self.assertIn('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR',
PROCESSING_MAIN)
# Incognito path
self.assertIn('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR',
PROCESSING_INCOGNITO)
def test_upload_handler_gated_on_has_video(self):
"""Upload handler skip is gated on `has_video`."""
self.assertIn('and has_video', RECORDINGS)
def test_file_monitor_gated_on_has_video(self):
"""File monitor skip is gated on `has_video`."""
self.assertIn('and has_video', FILE_MONITOR)
# ===========================================================================
# 8. Documentation
# ===========================================================================
class TestDocumentation(unittest.TestCase):
"""VIDEO_PASSTHROUGH_ASR is documented in all relevant places."""
DOC_FILES = [
'config/env.transcription.example',
'docs/admin-guide/system-settings.md',
'docs/features.md',
'docs/getting-started/installation.md',
]
def test_documented_in_all_relevant_files(self):
for rel_path in self.DOC_FILES:
content = read_file(rel_path)
with self.subTest(file=rel_path):
self.assertIn('VIDEO_PASSTHROUGH_ASR', content,
f"VIDEO_PASSTHROUGH_ASR missing from {rel_path}")
def test_env_example_commented_out_by_default(self):
"""The env example has the option commented out (opt-in)."""
self.assertIn('# VIDEO_PASSTHROUGH_ASR=false', ENV_EXAMPLE)
def test_docs_warn_about_asr_compatibility(self):
"""Docs warn that standard APIs will reject video input."""
system_settings = read_file('docs/admin-guide/system-settings.md')
installation = read_file('docs/getting-started/installation.md')
self.assertIn('reject', system_settings.lower())
self.assertIn('reject', installation.lower())
# ===========================================================================
# 9. Interaction matrix — structural verification
# ===========================================================================
class TestInteractionMatrix(unittest.TestCase):
"""
Verify the 3-way branch structure in processing.py:
if VIDEO_PASSTHROUGH_ASR → passthrough
elif VIDEO_RETENTION → retention
else → default extraction
"""
def test_three_way_branch_in_main_path(self):
"""Main path has if/elif/else for passthrough/retention/default."""
video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
# All three branches present in order
pt_pos = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
ret_pos = video_block.find('elif VIDEO_RETENTION:')
else_pos = video_block.find('\n else:', ret_pos)
self.assertNotEqual(pt_pos, -1)
self.assertNotEqual(ret_pos, -1)
self.assertNotEqual(else_pos, -1)
self.assertLess(pt_pos, ret_pos)
self.assertLess(ret_pos, else_pos)
def test_incognito_two_way_branch(self):
"""Incognito has if/else for passthrough/extract (no retention)."""
video_block = PROCESSING_INCOGNITO[PROCESSING_INCOGNITO.find('if is_video:'):]
pt_pos = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
else_pos = video_block.find('\n else:', pt_pos)
self.assertNotEqual(pt_pos, -1)
self.assertNotEqual(else_pos, -1)
# No VIDEO_RETENTION in incognito
incognito_video_block = video_block[:500]
self.assertNotIn('VIDEO_RETENTION', incognito_video_block)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -0,0 +1,370 @@
"""
Test suite for the VIDEO_RETENTION feature.
Tests code paths, configuration, and template correctness for video retention.
Does NOT require a running server or real video files - uses static analysis
and mocking where possible.
Run with: python tests/test_video_retention.py
"""
import os
import re
import sys
import json
import unittest
from unittest.mock import patch, MagicMock
from pathlib import Path
# Find project root
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(TEST_DIR)
sys.path.insert(0, PROJECT_ROOT)
class TestVideoRetentionConfig(unittest.TestCase):
"""Test that VIDEO_RETENTION env var is read correctly everywhere."""
ALL_FILES = [
'src/app.py',
'src/tasks/processing.py',
'src/api/system.py',
'src/api/recordings.py',
'src/file_monitor.py',
]
def _read_file(self, rel_path):
with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
return f.read()
def test_env_var_read_in_all_entry_points(self):
"""VIDEO_RETENTION env var is read in all files that need it."""
for rel_path in self.ALL_FILES:
content = self._read_file(rel_path)
self.assertIn("VIDEO_RETENTION", content, f"VIDEO_RETENTION missing from {rel_path}")
def test_exposed_in_api_config(self):
"""VIDEO_RETENTION is exposed in the /api/config response."""
content = self._read_file('src/api/system.py')
self.assertIn("'video_retention': VIDEO_RETENTION", content)
def test_default_is_false(self):
"""All VIDEO_RETENTION reads default to 'false'."""
for rel_path in self.ALL_FILES:
content = self._read_file(rel_path)
match = re.search(r"VIDEO_RETENTION\s*=\s*os\.environ\.get\('VIDEO_RETENTION',\s*'(\w+)'\)", content)
if match:
self.assertEqual(match.group(1), 'false', f"Default should be 'false' in {rel_path}")
class TestProcessingPipelineVideoRetention(unittest.TestCase):
"""Test processing.py video retention code paths via static analysis."""
@classmethod
def setUpClass(cls):
with open(os.path.join(PROJECT_ROOT, 'src/tasks/processing.py'), 'r') as f:
cls.content = f.read()
def test_video_retention_true_keeps_original(self):
"""When VIDEO_RETENTION=True, recording.audio_path is set to original filepath."""
# The VIDEO_RETENTION=True branch should set recording.audio_path = filepath
self.assertIn('recording.audio_path = filepath', self.content)
def test_video_retention_true_extracts_without_cleanup(self):
"""When VIDEO_RETENTION=True, extract_audio_from_video is called with cleanup_original=False."""
self.assertIn('extract_audio_from_video(filepath, cleanup_original=False)', self.content)
def test_video_retention_false_extracts_with_cleanup(self):
"""When VIDEO_RETENTION=False, extract_audio_from_video is called with default cleanup."""
self.assertIn('extract_audio_from_video(filepath)', self.content)
def test_temp_audio_cleanup_after_transcription(self):
"""Temp audio from video retention is cleaned up after transcription."""
self.assertIn('is_video and VIDEO_RETENTION and audio_filepath', self.content)
self.assertIn('Cleaned up temp audio from video retention', self.content)
def test_audio_filepath_initialized_to_none(self):
"""audio_filepath is initialized to None before the is_video check."""
# Find the initialization line
self.assertIn('audio_filepath = None', self.content)
def test_video_mime_type_set_for_retention(self):
"""When retaining video, mime_type reflects actual video type."""
self.assertIn("mimetypes.guess_type(filepath)[0] or 'video/mp4'", self.content)
def test_duration_uses_recording_audio_path(self):
"""Duration lookup uses recording.audio_path (always valid), not filepath."""
self.assertIn('chunking_service.get_audio_duration(recording.audio_path)', self.content)
# Should NOT use bare filepath for duration (pre-existing bug was fixed)
self.assertNotIn('chunking_service.get_audio_duration(filepath)', self.content)
class TestUploadHandlerVideoRetention(unittest.TestCase):
"""Test recordings.py upload handler video retention code paths."""
@classmethod
def setUpClass(cls):
with open(os.path.join(PROJECT_ROOT, 'src/api/recordings.py'), 'r') as f:
cls.content = f.read()
def test_upload_handler_skips_conversion_for_video_retention(self):
"""Upload handler skips convert_if_needed for videos when retention is on."""
self.assertIn('VIDEO_RETENTION and has_video', self.content)
self.assertIn('skipping conversion', self.content)
def test_upload_handler_has_video_from_codec_info(self):
"""Upload handler reads has_video from codec_info probe."""
self.assertIn("has_video = codec_info.get('has_video', False)", self.content)
def test_convert_if_needed_still_in_else_branch(self):
"""convert_if_needed still runs for non-video files or when retention is off."""
self.assertIn('convert_if_needed(', self.content)
def test_processing_pipeline_still_converts_audio(self):
"""Processing pipeline runs convert_if_needed on extracted audio (the safety net)."""
proc_content = open(os.path.join(PROJECT_ROOT, 'src/tasks/processing.py')).read()
# After the video extraction block, convert_if_needed runs on actual_filepath
self.assertIn('conversion_result = convert_if_needed(\n'
' filepath=actual_filepath,', proc_content)
class TestFileMonitorVideoRetention(unittest.TestCase):
"""Test file_monitor.py video retention code paths."""
@classmethod
def setUpClass(cls):
with open(os.path.join(PROJECT_ROOT, 'src/file_monitor.py'), 'r') as f:
cls.content = f.read()
def test_video_retention_skips_conversion(self):
"""When VIDEO_RETENTION=True and has_video=True, convert_if_needed is skipped."""
# Should have the guard: if VIDEO_RETENTION and has_video: ... skip conversion
self.assertIn('VIDEO_RETENTION and has_video', self.content)
self.assertIn('skipping conversion', self.content)
def test_no_double_extraction(self):
"""File monitor does NOT call convert_if_needed for videos when retention is on."""
# The convert_if_needed call should be in the else branch
lines = self.content.split('\n')
in_retention_skip_block = False
found_convert_in_else = False
for i, line in enumerate(lines):
if 'VIDEO_RETENTION and has_video' in line and 'if' in line:
in_retention_skip_block = True
elif in_retention_skip_block and 'else:' in line:
in_retention_skip_block = False
found_convert_in_else = True
elif in_retention_skip_block and 'convert_if_needed' in line:
self.fail(f"convert_if_needed called inside VIDEO_RETENTION skip block at line {i+1}")
self.assertTrue(found_convert_in_else, "Should have else branch after video retention skip")
def test_no_video_retention_param_in_convert_call(self):
"""convert_if_needed should NOT receive a video_retention parameter."""
# Ensure the old video_retention parameter isn't being passed
self.assertNotIn('video_retention=VIDEO_RETENTION', self.content)
class TestAudioConversionNotModified(unittest.TestCase):
"""Verify audio_conversion.py was fully reverted (no video_retention parameter)."""
@classmethod
def setUpClass(cls):
with open(os.path.join(PROJECT_ROOT, 'src/utils/audio_conversion.py'), 'r') as f:
cls.content = f.read()
def test_no_video_retention_parameter(self):
"""convert_if_needed should not have a video_retention parameter."""
self.assertNotIn('video_retention', self.content)
def test_no_should_delete_original(self):
"""No should_delete_original variable should exist."""
self.assertNotIn('should_delete_original', self.content)
class TestSendFileConditional(unittest.TestCase):
"""Test that send_file calls use conditional=True for range request support."""
def _read_file(self, rel_path):
with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
return f.read()
def test_recordings_streaming_has_conditional(self):
"""Streaming send_file in recordings.py has conditional=True."""
content = self._read_file('src/api/recordings.py')
# Find the non-download send_file call
self.assertIn('send_file(recording.audio_path, conditional=True)', content)
def test_recordings_download_has_conditional(self):
"""Download send_file in recordings.py has conditional=True."""
content = self._read_file('src/api/recordings.py')
self.assertIn('as_attachment=True, download_name=filename, conditional=True', content)
def test_shares_has_conditional(self):
"""send_file in shares.py has conditional=True."""
content = self._read_file('src/api/shares.py')
self.assertIn('send_file(recording.audio_path, conditional=True)', content)
class TestFrontendTemplates(unittest.TestCase):
"""Test that frontend templates correctly switch between video and audio."""
TEMPLATE_FILES = [
'templates/components/detail/desktop-right-panel.html',
'templates/components/detail/audio-player.html',
'templates/modals/speaker-modal.html',
'templates/share.html',
]
def _read_template(self, rel_path):
with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
return f.read()
def test_all_templates_use_dynamic_component(self):
"""All player templates use <component :is> for video/audio switching."""
for tmpl in self.TEMPLATE_FILES:
content = self._read_template(tmpl)
self.assertIn("<component :is=", content, f"Missing dynamic component in {tmpl}")
self.assertIn("startsWith('video/')", content, f"Missing video/ check in {tmpl}")
self.assertIn("</component>", content, f"Missing </component> in {tmpl}")
def test_no_bare_audio_elements_in_main_players(self):
"""Main player templates should not have bare <audio elements (replaced by component)."""
for tmpl in self.TEMPLATE_FILES:
content = self._read_template(tmpl)
# Count <audio and <component :is occurrences
audio_count = content.count('<audio ')
component_count = content.count('<component :is=')
# Each template should have component :is but no bare <audio for the main player
self.assertGreater(component_count, 0, f"No <component :is> in {tmpl}")
# Desktop right panel and audio player should have 0 bare audio tags
if 'desktop-right-panel' in tmpl or 'audio-player' in tmpl:
self.assertEqual(audio_count, 0, f"Unexpected bare <audio> in {tmpl}")
def test_video_element_gets_visible_styling(self):
"""When mime_type is video/, the element should be visible (not hidden)."""
for tmpl in self.TEMPLATE_FILES:
content = self._read_template(tmpl)
# Should have conditional class that shows video and hides audio
self.assertIn("'w-full rounded-lg", content, f"Missing video styling in {tmpl}")
self.assertIn("'hidden'", content, f"Missing hidden fallback for audio in {tmpl}")
def test_template_div_balance(self):
"""Verify player-specific templates have balanced div tags."""
# Only check templates we fully control (not share.html which has pre-existing imbalance)
balanced_templates = [
'templates/components/detail/desktop-right-panel.html',
'templates/components/detail/audio-player.html',
'templates/modals/speaker-modal.html',
]
for tmpl in balanced_templates:
content = self._read_template(tmpl)
opens = content.count('<div')
closes = content.count('</div>')
self.assertEqual(opens, closes, f"Unbalanced divs in {tmpl}: {opens} opens, {closes} closes")
class TestLocalization(unittest.TestCase):
"""Test that video retention localization keys exist in all locale files."""
LOCALE_DIR = os.path.join(PROJECT_ROOT, 'static', 'locales')
def test_video_retained_key_in_all_locales(self):
"""upload.videoRetained key exists in all locale files."""
locale_files = [f for f in os.listdir(self.LOCALE_DIR) if f.endswith('.json')]
self.assertGreater(len(locale_files), 0, "No locale files found")
for locale_file in locale_files:
filepath = os.path.join(self.LOCALE_DIR, locale_file)
with open(filepath, 'r', encoding='utf-8') as f:
data = json.load(f)
self.assertIn('upload', data, f"No 'upload' section in {locale_file}")
self.assertIn('videoRetained', data['upload'],
f"Missing 'videoRetained' key in upload section of {locale_file}")
self.assertIsInstance(data['upload']['videoRetained'], str,
f"'videoRetained' should be a string in {locale_file}")
self.assertGreater(len(data['upload']['videoRetained']), 0,
f"'videoRetained' is empty in {locale_file}")
def test_locale_files_are_valid_json(self):
"""All locale files are valid JSON."""
locale_files = [f for f in os.listdir(self.LOCALE_DIR) if f.endswith('.json')]
for locale_file in locale_files:
filepath = os.path.join(self.LOCALE_DIR, locale_file)
try:
with open(filepath, 'r', encoding='utf-8') as f:
json.load(f)
except json.JSONDecodeError as e:
self.fail(f"Invalid JSON in {locale_file}: {e}")
class TestVideoRetentionMatrix(unittest.TestCase):
"""
Test the complete 2x2 matrix of (VIDEO_RETENTION x is_video) scenarios
by analyzing the code flow statically.
"""
def _read_file(self, rel_path):
with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
return f.read()
def test_processing_has_both_branches(self):
"""processing.py has both VIDEO_RETENTION=True and False branches for video."""
content = self._read_file('src/tasks/processing.py')
# Should have if VIDEO_RETENTION: ... else: ... inside if is_video:
self.assertIn('if VIDEO_RETENTION:', content)
# After the retention block, should have else for the default behavior
lines = content.split('\n')
found_retention_if = False
found_else_after = False
for line in lines:
if 'if VIDEO_RETENTION:' in line:
found_retention_if = True
elif found_retention_if and line.strip().startswith('else:'):
found_else_after = True
break
self.assertTrue(found_else_after, "Missing else branch after VIDEO_RETENTION check in processing.py")
def test_file_monitor_has_both_branches(self):
"""file_monitor.py has both video retention skip and normal conversion paths."""
content = self._read_file('src/file_monitor.py')
self.assertIn('VIDEO_RETENTION and has_video', content)
# convert_if_needed should still exist in the else path
self.assertIn('convert_if_needed(', content)
def test_incognito_not_affected(self):
"""Incognito processing path should NOT reference VIDEO_RETENTION."""
content = self._read_file('src/tasks/processing.py')
# Find the incognito section (marked with [Incognito]); assert the marker
# exists so a missing marker can't silently slice only the last character
idx = content.find('[Incognito]')
self.assertNotEqual(idx, -1, "Expected an '[Incognito]' marker in processing.py")
incognito_section = content[idx:]
# VIDEO_RETENTION should not appear in incognito section
# (incognito always strips video per the plan)
self.assertNotIn('VIDEO_RETENTION', incognito_section,
"VIDEO_RETENTION should not be referenced in incognito processing")
def test_all_three_entry_points_skip_for_video_retention(self):
"""All entry points (upload, file monitor, processing) handle VIDEO_RETENTION."""
for rel_path, marker in [
('src/api/recordings.py', 'VIDEO_RETENTION and has_video'),
('src/file_monitor.py', 'VIDEO_RETENTION and has_video'),
('src/tasks/processing.py', 'if VIDEO_RETENTION:'),
]:
content = self._read_file(rel_path)
self.assertIn(marker, content, f"Missing video retention guard in {rel_path}")
def test_convert_if_needed_always_runs_on_transcription_audio(self):
"""Processing pipeline always runs convert_if_needed on audio before transcription."""
content = self._read_file('src/tasks/processing.py')
# The convert_if_needed call on actual_filepath happens AFTER the video
# extraction block, regardless of VIDEO_RETENTION setting
video_block_pos = content.find('if is_video:')
convert_pos = content.find('convert_if_needed(\n filepath=actual_filepath,')
self.assertGreater(convert_pos, video_block_pos,
"convert_if_needed must run after video extraction block")
if __name__ == '__main__':
unittest.main(verbosity=2)