Initial release: DictIA v0.8.14-alpha (fork of Speakr, AGPL-3.0)
tests/test_api_v1_speakers.py (new file, 971 lines)
@@ -0,0 +1,971 @@
#!/usr/bin/env python3
"""
Test suite for Speaker API v1 endpoints.

Covers:
- PUT /recordings/<id>/speakers/assign (17 tests)
- POST /recordings/<id>/speakers/identify (10 tests)
- PUT /settings/auto-summarization (5 tests)
- Regression for GET /speakers and GET /recordings/<id>/speakers (2 tests)

Pattern follows tests/test_api_v1_upload.py — standalone, no pytest fixtures.
"""

import json
import secrets
import sys
import os
from unittest.mock import patch, MagicMock

# Add parent directory so we can import the app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.app import app, db
from src.models import User, APIToken, Recording, Speaker
from src.utils.token_auth import hash_token

# ---------------------------------------------------------------------------
# Test data constants
# ---------------------------------------------------------------------------

SAMPLE_TRANSCRIPTION_JSON = json.dumps([
    {"speaker": "SPEAKER_00", "sentence": "Hi, I'm Alice."},
    {"speaker": "SPEAKER_01", "sentence": "Hello Alice, I'm Bob."},
    {"speaker": "SPEAKER_00", "sentence": "Nice to meet you, Bob."},
])

SAMPLE_TRANSCRIPTION_TEXT = (
    "[SPEAKER_00]: Hi, I'm Alice.\n"
    "[SPEAKER_01]: Hello Alice, I'm Bob.\n"
    "[SPEAKER_00]: Nice to meet you, Bob."
)

SAMPLE_EMBEDDINGS = json.dumps({
    "SPEAKER_00": [0.1] * 256,
    "SPEAKER_01": [0.2] * 256,
})
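
# Note: the constant vectors above are dummy stand-ins for real voice
# embeddings; only their JSON shape matters to these tests, and the 256-dim
# size is an arbitrary test choice, not a requirement of the matcher.
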
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _get_or_create_test_user(suffix=""):
    """Get or create a test user. Returns (user, created_bool)."""
    username = f"speaker_test_user{suffix}"
    user = User.query.filter_by(username=username).first()
    created = False
    if not user:
        user = User(
            username=username,
            email=f"{username}@local.test",
            name="Test User" if not suffix else None,
        )
        db.session.add(user)
        db.session.commit()
        created = True
    return user, created


def _create_api_token(user):
    """Create a fresh API token. Returns (token_record, plaintext)."""
    plaintext = f"test-token-{secrets.token_urlsafe(16)}"
    token = APIToken(
        user_id=user.id,
        token_hash=hash_token(plaintext),
        name="test-api-token",
    )
    db.session.add(token)
    db.session.commit()
    return token, plaintext

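# Only the hash of the token is persisted (via hash_token); the plaintext
# is returned solely so tests can send it in the X-API-Token header.
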
def _create_test_recording(user, transcription=None, speaker_embeddings=None, status="COMPLETED"):
    """Create a Recording owned by *user*."""
    rec = Recording(
        user_id=user.id,
        title="Test Recording",
        status=status,
        transcription=transcription,
        speaker_embeddings=speaker_embeddings,
    )
    db.session.add(rec)
    db.session.commit()
    return rec


def _create_test_speaker(user, name="Alice"):
    """Create a Speaker owned by *user*."""
    speaker = Speaker(name=name, user_id=user.id)
    db.session.add(speaker)
    db.session.commit()
    return speaker


def _cleanup(*objects):
    """Delete DB objects in reverse order, committing once."""
    for obj in reversed(objects):
        try:
            db.session.delete(obj)
        except Exception:
            db.session.rollback()
            try:
                merged = db.session.merge(obj)
                db.session.delete(merged)
            except Exception:
                pass
    db.session.commit()

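# Note on _cleanup: delete() raises for instances that were detached by an
# earlier rollback; the merge() fallback reattaches such objects to the
# session so the delete can still proceed.
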
# =========================================================================
# Group 1: PUT /recordings/<id>/speakers/assign (17 tests)
# =========================================================================


def test_assign_no_auth():
    """No token -> 302 redirect (Flask-Login) or 401."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              json={"speaker_map": {}})
            assert resp.status_code in (302, 401), f"Expected 302/401, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec)
            if cu:
                _cleanup(user)


def test_assign_recording_not_found():
    """Nonexistent recording ID -> 404."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        client = app.test_client()
        try:
            resp = client.put("/api/v1/recordings/999999/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {}})
            assert resp.status_code == 404, f"Expected 404, got {resp.status_code}"
            return True
        finally:
            _cleanup(token_rec)
            if cu:
                _cleanup(user)


def test_assign_wrong_user_recording():
    """Other user's recording -> 403."""
    with app.app_context():
        owner, co = _get_or_create_test_user("_owner")
        other, cu = _get_or_create_test_user("_other")
        token_rec, token = _create_api_token(other)
        rec = _create_test_recording(owner, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {"SPEAKER_00": "Alice"}})
            assert resp.status_code == 403, f"Expected 403, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(other)
            if co:
                _cleanup(owner)


def test_assign_missing_speaker_map():
    """Body {} -> 400 'speaker_map is required'."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={})
            assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
            body = resp.get_json()
            assert "speaker_map" in body.get("error", "").lower(), f"Unexpected error: {body}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_invalid_speaker_map_type():
    """speaker_map: 'string' -> 400."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": "not a dict"})
            assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_string_value_json_transcript():
    """Happy path: string names update JSON segments + participants."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {"SPEAKER_00": "Alice", "SPEAKER_01": "Bob"}})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            assert body.get("success") is True
            # Verify participants
            participants = body["recording"]["participants"]
            assert "Alice" in participants and "Bob" in participants
            # Verify transcription was updated
            db.session.refresh(rec)
            segments = json.loads(rec.transcription)
            assert segments[0]["speaker"] == "Alice"
            assert segments[1]["speaker"] == "Bob"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_object_value_with_name():
    """Happy path: {name, isMe} object format."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {
                                  "SPEAKER_00": {"name": "Alice", "isMe": False},
                              }})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            db.session.refresh(rec)
            segments = json.loads(rec.transcription)
            assert segments[0]["speaker"] == "Alice"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_is_me_flag_with_user_name():
    """isMe: true resolves to user.name."""
    with app.app_context():
        user, cu = _get_or_create_test_user()  # user.name == "Test User"
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {
                                  "SPEAKER_00": {"name": "", "isMe": True},
                              }})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            db.session.refresh(rec)
            segments = json.loads(rec.transcription)
            assert segments[0]["speaker"] == "Test User", f"Got {segments[0]['speaker']}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_is_me_flag_without_user_name():
    """isMe: true falls back to 'Me' when user.name is None."""
    with app.app_context():
        user, cu = _get_or_create_test_user("_noname")
        # Ensure user.name is None
        user.name = None
        db.session.commit()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {
                                  "SPEAKER_00": {"name": "", "isMe": True},
                              }})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            db.session.refresh(rec)
            segments = json.loads(rec.transcription)
            assert segments[0]["speaker"] == "Me", f"Got {segments[0]['speaker']}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_plain_text_transcript():
    """Replaces [SPEAKER_XX] in plain text format."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_TEXT)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {"SPEAKER_00": "Alice", "SPEAKER_01": "Bob"}})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            db.session.refresh(rec)
            assert "[Alice]" in rec.transcription
            assert "[Bob]" in rec.transcription
            assert "[SPEAKER_00]" not in rec.transcription
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_speaker_xx_filtered_from_participants():
    """Unresolved SPEAKER_XX labels excluded from participants."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            # Only assign one speaker - SPEAKER_01 stays unresolved
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {"SPEAKER_00": "Alice"}})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            participants = body["recording"]["participants"]
            assert "SPEAKER_01" not in participants, f"SPEAKER_01 should be filtered: {participants}"
            assert "Alice" in participants
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_invalid_value_type():
    """Array value -> 400 'Invalid value type'."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {"SPEAKER_00": [1, 2, 3]}})
            assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
            body = resp.get_json()
            assert "invalid value type" in body.get("error", "").lower(), f"Unexpected: {body}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_empty_speaker_map():
    """Empty speaker_map {} -> 200 with no changes."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {}})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            assert body.get("success") is True
            # Transcription should be unchanged
            db.session.refresh(rec)
            segments = json.loads(rec.transcription)
            assert segments[0]["speaker"] == "SPEAKER_00"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_regenerate_summary():
    """regenerate_summary: true -> job_queue.enqueue called, summary_queued: true."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            mock_jq = MagicMock()
            mock_jq.enqueue = MagicMock(return_value="job-123")
            with patch("src.services.job_queue.job_queue", mock_jq):
                resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                                  headers={"X-API-Token": token},
                                  json={
                                      "speaker_map": {"SPEAKER_00": "Alice"},
                                      "regenerate_summary": True,
                                  })
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            assert body.get("summary_queued") is True
            mock_jq.enqueue.assert_called_once()
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_embeddings_updated():
    """With speaker_embeddings -> update_speaker_embedding called, counts returned."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(
            user,
            transcription=SAMPLE_TRANSCRIPTION_JSON,
            speaker_embeddings=SAMPLE_EMBEDDINGS,
        )
        speaker = _create_test_speaker(user, "Alice")
        client = app.test_client()
        try:
            mock_update = MagicMock()
            mock_snippets = MagicMock(return_value=2)
            with patch("src.services.speaker_embedding_matcher.update_speaker_embedding", mock_update), \
                 patch("src.services.speaker_snippets.create_speaker_snippets", mock_snippets):
                resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                                  headers={"X-API-Token": token},
                                  json={"speaker_map": {"SPEAKER_00": "Alice"}})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            assert body.get("embeddings_updated") >= 1, f"embeddings_updated: {body}"
            mock_update.assert_called()
            return True
        finally:
            _cleanup(rec, speaker, token_rec)
            if cu:
                _cleanup(user)


def test_assign_no_transcription():
    """Recording without transcription -> speakers applied to empty content gracefully."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=None)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {"SPEAKER_00": "Alice"}})
            # Should succeed (or at least not 500)
            assert resp.status_code in (200, 400), f"Expected 200/400, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_assign_whitespace_name_trimmed():
    """Names with leading/trailing whitespace get trimmed."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.put(f"/api/v1/recordings/{rec.id}/speakers/assign",
                              headers={"X-API-Token": token},
                              json={"speaker_map": {"SPEAKER_00": " Alice "}})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            db.session.refresh(rec)
            segments = json.loads(rec.transcription)
            assert segments[0]["speaker"] == "Alice", f"Name not trimmed: '{segments[0]['speaker']}'"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)

# =========================================================================
# Group 2: POST /recordings/<id>/speakers/identify (10 tests)
# =========================================================================


def test_identify_no_auth():
    """No token -> 302 or 401."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify")
            assert resp.status_code in (302, 401), f"Expected 302/401, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec)
            if cu:
                _cleanup(user)


def test_identify_recording_not_found():
    """Nonexistent ID -> 404."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        client = app.test_client()
        try:
            resp = client.post("/api/v1/recordings/999999/speakers/identify",
                               headers={"X-API-Token": token})
            assert resp.status_code == 404, f"Expected 404, got {resp.status_code}"
            return True
        finally:
            _cleanup(token_rec)
            if cu:
                _cleanup(user)


def test_identify_wrong_user_recording():
    """Other user's recording -> 403."""
    with app.app_context():
        owner, co = _get_or_create_test_user("_id_owner")
        other, cu = _get_or_create_test_user("_id_other")
        token_rec, token = _create_api_token(other)
        rec = _create_test_recording(owner, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
                               headers={"X-API-Token": token})
            assert resp.status_code == 403, f"Expected 403, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(other)
            if co:
                _cleanup(owner)


def test_identify_no_transcription():
    """No transcription -> 400."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=None)
        client = app.test_client()
        try:
            resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
                               headers={"X-API-Token": token})
            assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_identify_non_json_transcription():
    """Plain text -> 400."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_TEXT)
        client = app.test_client()
        try:
            resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
                               headers={"X-API-Token": token})
            assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_identify_json_but_not_list():
    """Dict JSON -> 400."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=json.dumps({"key": "value"}))
        client = app.test_client()
        try:
            resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
                               headers={"X-API-Token": token})
            assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_identify_happy_path():
    """Mock LLM returns names -> 200 with speaker_map."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            # Build a mock LLM completion response
            mock_completion = MagicMock()
            mock_completion.choices = [MagicMock()]
            mock_completion.choices[0].message.content = json.dumps({
                "SPEAKER_00": "Alice",
                "SPEAKER_01": "Bob",
            })

            with patch("src.services.llm.call_llm_completion", return_value=mock_completion), \
                 patch("src.models.system.SystemSetting") as mock_ss:
                mock_ss.get_setting.return_value = 30000
                resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
                                   headers={"X-API-Token": token})

            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            assert body.get("success") is True
            sm = body.get("speaker_map", {})
            assert sm.get("SPEAKER_00") == "Alice"
            assert sm.get("SPEAKER_01") == "Bob"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)

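# The mock completion above assumes an OpenAI-style response shape
# (choices[0].message.content holding a JSON string), which is the shape
# these tests expect call_llm_completion to return.
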
def test_identify_post_processing_unknown_values():
    """'Unknown'/'N/A' cleared to ''."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            mock_completion = MagicMock()
            mock_completion.choices = [MagicMock()]
            mock_completion.choices[0].message.content = json.dumps({
                "SPEAKER_00": "Unknown",
                "SPEAKER_01": "N/A",
            })

            with patch("src.services.llm.call_llm_completion", return_value=mock_completion), \
                 patch("src.models.system.SystemSetting") as mock_ss:
                mock_ss.get_setting.return_value = 30000
                resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
                                   headers={"X-API-Token": token})

            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            sm = body.get("speaker_map", {})
            assert sm.get("SPEAKER_00") == "", f"Expected empty, got {sm.get('SPEAKER_00')}"
            assert sm.get("SPEAKER_01") == "", f"Expected empty, got {sm.get('SPEAKER_01')}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_identify_no_speakers_in_transcript():
    """Segments without speaker field -> 400."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        no_speakers = json.dumps([{"sentence": "Hello"}, {"sentence": "World"}])
        rec = _create_test_recording(user, transcription=no_speakers)
        client = app.test_client()
        try:
            resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
                               headers={"X-API-Token": token})
            assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)


def test_identify_llm_error():
    """LLM raises exception -> 500."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            with patch("src.services.llm.call_llm_completion",
                       side_effect=RuntimeError("LLM down")), \
                 patch("src.models.system.SystemSetting") as mock_ss:
                mock_ss.get_setting.return_value = 30000
                resp = client.post(f"/api/v1/recordings/{rec.id}/speakers/identify",
                                   headers={"X-API-Token": token})
            assert resp.status_code == 500, f"Expected 500, got {resp.status_code}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)

# =========================================================================
# Group 3: PUT /settings/auto-summarization (5 tests)
# =========================================================================


def test_auto_summarization_no_auth():
    """No token -> 302 or 401."""
    with app.app_context():
        client = app.test_client()
        resp = client.put("/api/v1/settings/auto-summarization",
                          json={"enabled": True})
        assert resp.status_code in (302, 401), f"Expected 302/401, got {resp.status_code}"
        return True


def test_auto_summarization_missing_enabled():
    """Body {} -> 400."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        client = app.test_client()
        try:
            resp = client.put("/api/v1/settings/auto-summarization",
                              headers={"X-API-Token": token},
                              json={})
            assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
            body = resp.get_json()
            assert "enabled" in body.get("error", "").lower(), f"Unexpected: {body}"
            return True
        finally:
            _cleanup(token_rec)
            if cu:
                _cleanup(user)


def test_auto_summarization_invalid_json():
    """Non-JSON body -> 400."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        client = app.test_client()
        try:
            resp = client.put("/api/v1/settings/auto-summarization",
                              headers={"X-API-Token": token,
                                       "Content-Type": "application/json"},
                              data="not valid json")
            assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
            return True
        finally:
            _cleanup(token_rec)
            if cu:
                _cleanup(user)


def test_auto_summarization_enable():
    """enabled: true -> updates user, returns true."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        user.auto_summarization = False
        db.session.commit()
        token_rec, token = _create_api_token(user)
        client = app.test_client()
        try:
            resp = client.put("/api/v1/settings/auto-summarization",
                              headers={"X-API-Token": token},
                              json={"enabled": True})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            assert body.get("auto_summarization") is True
            db.session.refresh(user)
            assert user.auto_summarization is True
            return True
        finally:
            _cleanup(token_rec)
            if cu:
                _cleanup(user)


def test_auto_summarization_disable():
    """enabled: false -> updates user, returns false."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        user.auto_summarization = True
        db.session.commit()
        token_rec, token = _create_api_token(user)
        client = app.test_client()
        try:
            resp = client.put("/api/v1/settings/auto-summarization",
                              headers={"X-API-Token": token},
                              json={"enabled": False})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            assert body.get("auto_summarization") is False
            db.session.refresh(user)
            assert user.auto_summarization is False
            return True
        finally:
            _cleanup(token_rec)
            if cu:
                _cleanup(user)

# =========================================================================
# Group 4: Regression tests (2 tests)
# =========================================================================


def test_regression_get_speakers_list():
    """GET /speakers still returns user's speakers."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        speaker = _create_test_speaker(user, "Regression Speaker")
        client = app.test_client()
        try:
            resp = client.get("/api/v1/speakers",
                              headers={"X-API-Token": token})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            names = [s["name"] for s in body.get("speakers", [])]
            assert "Regression Speaker" in names, f"Speaker not found: {names}"
            return True
        finally:
            _cleanup(speaker, token_rec)
            if cu:
                _cleanup(user)


def test_regression_get_recording_speakers():
    """GET /recordings/<id>/speakers still returns transcript speakers."""
    with app.app_context():
        user, cu = _get_or_create_test_user()
        token_rec, token = _create_api_token(user)
        rec = _create_test_recording(user, transcription=SAMPLE_TRANSCRIPTION_JSON)
        client = app.test_client()
        try:
            with patch("src.services.speaker_embedding_matcher.find_matching_speakers", return_value={}):
                resp = client.get(f"/api/v1/recordings/{rec.id}/speakers",
                                  headers={"X-API-Token": token})
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            body = resp.get_json()
            labels = [s["label"] for s in body.get("speakers", [])]
            assert "SPEAKER_00" in labels and "SPEAKER_01" in labels, f"Labels: {labels}"
            return True
        finally:
            _cleanup(rec, token_rec)
            if cu:
                _cleanup(user)

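# find_matching_speakers is patched out above so the regression test does
# not depend on a real embedding matcher (or its model files) being
# available in the test environment.
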
# =========================================================================
# Runner
# =========================================================================

ALL_TESTS = [
    # Group 1: assign
    test_assign_no_auth,
    test_assign_recording_not_found,
    test_assign_wrong_user_recording,
    test_assign_missing_speaker_map,
    test_assign_invalid_speaker_map_type,
    test_assign_string_value_json_transcript,
    test_assign_object_value_with_name,
    test_assign_is_me_flag_with_user_name,
    test_assign_is_me_flag_without_user_name,
    test_assign_plain_text_transcript,
    test_assign_speaker_xx_filtered_from_participants,
    test_assign_invalid_value_type,
    test_assign_empty_speaker_map,
    test_assign_regenerate_summary,
    test_assign_embeddings_updated,
    test_assign_no_transcription,
    test_assign_whitespace_name_trimmed,
    # Group 2: identify
    test_identify_no_auth,
    test_identify_recording_not_found,
    test_identify_wrong_user_recording,
    test_identify_no_transcription,
    test_identify_non_json_transcription,
    test_identify_json_but_not_list,
    test_identify_happy_path,
    test_identify_post_processing_unknown_values,
    test_identify_no_speakers_in_transcript,
    test_identify_llm_error,
    # Group 3: auto-summarization
    test_auto_summarization_no_auth,
    test_auto_summarization_missing_enabled,
    test_auto_summarization_invalid_json,
    test_auto_summarization_enable,
    test_auto_summarization_disable,
    # Group 4: regression
    test_regression_get_speakers_list,
    test_regression_get_recording_speakers,
]


def main():
    print(f"Running {len(ALL_TESTS)} Speaker API tests...\n")
    passed = 0
    failed = 0
    errors = []

    for test_fn in ALL_TESTS:
        name = test_fn.__name__
        try:
            result = test_fn()
            if result:
                print(f"  PASS {name}")
                passed += 1
            else:
                print(f"  FAIL {name} (returned False)")
                failed += 1
                errors.append(name)
        except Exception as e:
            print(f"  ERROR {name}: {e}")
            failed += 1
            errors.append(name)

    print(f"\n{'=' * 60}")
    print(f"Results: {passed} passed, {failed} failed out of {len(ALL_TESTS)}")
    if errors:
        print("Failed tests:")
        for e in errors:
            print(f"  - {e}")
    print('=' * 60)
    sys.exit(0 if failed == 0 else 1)


if __name__ == "__main__":
    main()

tests/test_api_v1_upload.py (new file, 82 lines)
@@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""
Integration test for API v1 recording upload endpoint.

Validates API token authentication and expected 400 response when no file is provided.
"""

import secrets
import sys
import os

# Add the parent directory to the path to import app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.app import app, db
from src.models import User, APIToken
from src.utils.token_auth import hash_token


def _get_or_create_test_user():
    user = User.query.filter_by(username="api_test_user").first()
    created = False
    if not user:
        user = User(username="api_test_user", email="api_test_user@local.test")
        db.session.add(user)
        db.session.commit()
        created = True
    return user, created


def _create_api_token(user):
    plaintext = f"test-token-{secrets.token_urlsafe(16)}"
    token = APIToken(
        user_id=user.id,
        token_hash=hash_token(plaintext),
        name="test-api-token"
    )
    db.session.add(token)
    db.session.commit()
    return token, plaintext


def test_upload_requires_file():
    with app.app_context():
        user, created_user = _get_or_create_test_user()
        token_record, token = _create_api_token(user)
        client = app.test_client()

        try:
            response = client.post(
                "/api/v1/recordings/upload",
                headers={"X-API-Token": token}
            )

            if response.status_code != 400:
                print(f"❌ Expected 400, got {response.status_code}")
                return False

            payload = response.get_json(silent=True) or {}
            if payload.get("error") != "No file provided":
                print(f"❌ Unexpected error payload: {payload}")
                return False

            print("✅ Token auth works and missing file returns 400 as expected")
            return True
        finally:
            db.session.delete(token_record)
            db.session.commit()
            if created_user:
                db.session.delete(user)
                db.session.commit()


def main():
    print("🚀 Running API v1 upload test...\n")
    ok = test_upload_requires_file()
    print("\n" + ("✅ PASS" if ok else "❌ FAIL"))
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()

tests/test_audit.py (new file, 331 lines)
@@ -0,0 +1,331 @@
#!/usr/bin/env python3
"""
Tests for the Loi 25 audit system.

Covers:
- audit_access(): adds to session, does NOT commit
- audit_login(): commits independently
- audit_failed_login(): commits, uses email_hash (not plain email)
- _is_recent_duplicate(): deduplication window
- get_access_logs() / get_auth_logs(): pagination and filters
- Admin API endpoints: /api/admin/audit/status, /access, /auth
"""

import os
import sys
import hashlib

# Add the parent directory to the path to import app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

os.environ.setdefault('SQLALCHEMY_DATABASE_URI', 'sqlite:///:memory:')
os.environ['ENABLE_AUDIT_LOG'] = 'true'
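# These must be set before importing src.app so the audit system sees them
# at import time.
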
from src.app import app, db
from src.models import User
from src.models.access_log import AccessLog
from src.models.auth_log import AuthLog


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_user(username, is_admin=False):
    user = User(username=username, email=f"{username}@test.local", is_admin=is_admin)
    user.set_password("TestPass1!")
    db.session.add(user)
    db.session.commit()
    return user


def _login_client(client, user):
    """Log a user in by injecting Flask-Login keys into the session cookie."""
    with client.session_transaction() as sess:
        sess['_user_id'] = str(user.id)
        sess['_fresh'] = True

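# Injecting _user_id/_fresh directly into the session is a lightweight
# alternative to posting real credentials to the login endpoint in tests.
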
# ---------------------------------------------------------------------------
# Service-level tests
# ---------------------------------------------------------------------------

def test_audit_access_no_commit():
    """audit_access() adds to session but does NOT commit."""
    with app.app_context():
        db.create_all()
        user = _make_user("audit_no_commit")
        try:
            initial_count = AccessLog.query.count()

            from src.services.audit import audit_access
            with app.test_request_context():
                from flask_login import login_user
                login_user(user)
                log = audit_access('edit', 'recording', 1, user_id=user.id)

            # Log object returned but not yet in DB (no commit happened)
            assert log is not None, "audit_access should return a log object"
            # The session hasn't been committed, so count is still the same
            # (SQLite in :memory: — flush to verify it's in session)
            db.session.flush()
            assert AccessLog.query.count() == initial_count + 1, "Log should be in session after flush"
            db.session.rollback()  # undo — simulates caller rollback
            assert AccessLog.query.count() == initial_count, "Log should be gone after rollback"

            print("✅ audit_access() does not commit")
            return True
        finally:
            db.session.delete(user)
            db.session.commit()


def test_audit_login_commits():
    """audit_login() commits its own transaction."""
    with app.app_context():
        db.create_all()
        user = _make_user("audit_login_user")
        try:
            initial_count = AuthLog.query.count()

            from src.services.audit import audit_login
            with app.test_request_context():
                audit_login(user.id)

            # Should be committed — visible in a fresh query
            assert AuthLog.query.count() == initial_count + 1, "auth log should be committed"
            log = AuthLog.query.order_by(AuthLog.id.desc()).first()
            assert log.action == 'login'
            assert log.user_id == user.id

            print("✅ audit_login() commits independently")
            return True
        finally:
            AuthLog.query.filter_by(user_id=user.id).delete()
            db.session.delete(user)
            db.session.commit()


def test_audit_failed_login_uses_email_hash():
    """audit_failed_login() stores email_hash, not plain email."""
    with app.app_context():
        db.create_all()
        initial_count = AuthLog.query.count()
        email = "attacker-target@example.com"
        email_hash = hashlib.sha256(email.lower().encode()).hexdigest()[:16]

        from src.services.audit import audit_failed_login
        with app.test_request_context():
            audit_failed_login(details={'email_hash': email_hash, 'reason': 'wrong_password'})

        assert AuthLog.query.count() == initial_count + 1
        log = AuthLog.query.order_by(AuthLog.id.desc()).first()
        assert log.action == 'failed_login'
        assert log.details is not None
        assert 'email_hash' in log.details, "Should store email_hash, not plain email"
        assert 'email' not in log.details, "Should NOT store plain email"
        assert log.details['email_hash'] == email_hash

        # Cleanup
        db.session.delete(log)
        db.session.commit()

        print("✅ audit_failed_login() stores email_hash, not plain email")
        return True


def test_audit_view_deduplication():
    """audit_view() on the same resource within 5 min creates only one log entry."""
    with app.app_context():
        db.create_all()
        user = _make_user("audit_dedup_user")
        try:
            initial_count = AccessLog.query.count()

            from src.services.audit import audit_view
            with app.test_request_context():
                from flask_login import login_user
                login_user(user)
                # First view — should log
                log1 = audit_view('recording', 42, user_id=user.id)
                db.session.commit()
                # Second view within 5 min — should be deduped
                log2 = audit_view('recording', 42, user_id=user.id)
                if log2 is not None:
                    db.session.commit()

            count_after = AccessLog.query.filter_by(
                user_id=user.id, action='view', resource_type='recording', resource_id=42
            ).count()
            assert count_after == 1, f"Expected 1 log entry, got {count_after} (dedup failed)"

            print("✅ audit_view() deduplication works (5-min window)")
            return True
        finally:
            AccessLog.query.filter_by(user_id=user.id).delete()
            db.session.delete(user)
            db.session.commit()


def test_get_access_logs_pagination():
    """get_access_logs() returns paginated results."""
    with app.app_context():
        db.create_all()
        user = _make_user("audit_pag_user")
        try:
            # Create 5 access log entries
            for i in range(5):
                log = AccessLog.log_access(
                    action='view', resource_type='recording', resource_id=i,
                    user_id=user.id, status='success',
                )
                db.session.add(log)
            db.session.commit()

            from src.services.audit import get_access_logs
            page1 = get_access_logs(page=1, per_page=3, user_id=user.id)
            assert page1.total >= 5
            assert len(page1.items) == 3

            page2 = get_access_logs(page=2, per_page=3, user_id=user.id)
            assert len(page2.items) >= 2

            print("✅ get_access_logs() pagination works")
            return True
        finally:
            AccessLog.query.filter_by(user_id=user.id).delete()
            db.session.delete(user)
            db.session.commit()


# ---------------------------------------------------------------------------
# Admin API endpoint tests
# ---------------------------------------------------------------------------

def test_audit_status_requires_admin():
    """GET /api/admin/audit/status: 401 anon, 403 non-admin, 200 admin."""
    with app.app_context():
        db.create_all()
        regular = _make_user("audit_status_regular")
        admin = _make_user("audit_status_admin", is_admin=True)
        client = app.test_client()

        try:
            # Anonymous — should redirect to login (302) or 401
            resp = client.get('/api/admin/audit/status')
            assert resp.status_code in (401, 302), f"Expected 401/302 for anon, got {resp.status_code}"

            # Regular user — should get 403
            _login_client(client, regular)
            resp = client.get('/api/admin/audit/status')
            assert resp.status_code == 403, f"Expected 403 for non-admin, got {resp.status_code}"

            # Admin — should get 200
            _login_client(client, admin)
            resp = client.get('/api/admin/audit/status')
            assert resp.status_code == 200, f"Expected 200 for admin, got {resp.status_code}"
            data = resp.get_json()
            assert 'enabled' in data

            print("✅ /api/admin/audit/status access control works")
            return True
        finally:
            db.session.delete(regular)
            db.session.delete(admin)
            db.session.commit()


def test_audit_access_logs_endpoint():
    """GET /api/admin/audit/access returns paginated logs for admin."""
    with app.app_context():
        db.create_all()
        admin = _make_user("audit_access_ep_admin", is_admin=True)
        client = app.test_client()

        try:
            _login_client(client, admin)
            resp = client.get('/api/admin/audit/access?per_page=10')
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            data = resp.get_json()
            assert 'logs' in data
            assert 'total' in data
            assert 'page' in data

            print("✅ /api/admin/audit/access returns correct structure")
            return True
        finally:
            db.session.delete(admin)
            db.session.commit()


def test_audit_auth_logs_endpoint():
    """GET /api/admin/audit/auth returns paginated auth logs for admin."""
    with app.app_context():
        db.create_all()
        admin = _make_user("audit_auth_ep_admin", is_admin=True)
        client = app.test_client()

        try:
            _login_client(client, admin)
            resp = client.get('/api/admin/audit/auth?per_page=10')
            assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
            data = resp.get_json()
            assert 'logs' in data
            assert 'total' in data

            print("✅ /api/admin/audit/auth returns correct structure")
            return True
        finally:
            db.session.delete(admin)
            db.session.commit()


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------

def main():
    print("🚀 Running audit system tests...\n")
    tests = [
        test_audit_access_no_commit,
        test_audit_login_commits,
        test_audit_failed_login_uses_email_hash,
        test_audit_view_deduplication,
        test_get_access_logs_pagination,
        test_audit_status_requires_admin,
        test_audit_access_logs_endpoint,
        test_audit_auth_logs_endpoint,
    ]

    passed = 0
    failed = 0
    for test in tests:
        try:
            result = test()
            if result:
                passed += 1
            else:
                print(f"❌ {test.__name__} returned False")
                failed += 1
        except Exception as e:
            print(f"❌ {test.__name__} raised: {e}")
            import traceback
            traceback.print_exc()
            failed += 1

    print(f"\n{'='*40}")
    print(f"Results: {passed} passed, {failed} failed")
    print("✅ ALL PASS" if failed == 0 else "❌ SOME FAILED")
    sys.exit(0 if failed == 0 else 1)


if __name__ == "__main__":
    main()

tests/test_bugfixes.py (new file, 239 lines)
@@ -0,0 +1,239 @@
#!/usr/bin/env python3
"""
Tests for specific bug fixes.

- Issue #230: Bulk delete crash when recordings have speaker_snippets
- Issue #223: File monitor stability time env var
"""

import json
import os
import sys
import time
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.app import app, db
from src.models import User, Recording
from src.models.speaker_snippet import SpeakerSnippet

# Disable CSRF for testing
app.config['WTF_CSRF_ENABLED'] = False

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _get_or_create_user():
    user = User.query.filter_by(username="bugfix_test_user").first()
    if not user:
        user = User(username="bugfix_test_user", email="bugfix@local.test")
        db.session.add(user)
        db.session.commit()
    return user


def _create_recording_with_snippets(user):
    """Create a recording that has speaker_snippet records attached."""
    rec = Recording(
        user_id=user.id,
        title="Recording with snippets",
        status="COMPLETED",
        transcription=json.dumps([
            {"speaker": "SPEAKER_00", "sentence": "Hello there."},
        ]),
    )
    db.session.add(rec)
    db.session.commit()

    # We need a speaker to attach snippets to
    from src.models import Speaker
    speaker = Speaker.query.filter_by(user_id=user.id, name="BugfixTestSpeaker").first()
    if not speaker:
        speaker = Speaker(name="BugfixTestSpeaker", user_id=user.id)
        db.session.add(speaker)
        db.session.commit()

    snippet = SpeakerSnippet(
        speaker_id=speaker.id,
        recording_id=rec.id,
        segment_index=0,
        text_snippet="Hello there.",
    )
    db.session.add(snippet)
    db.session.commit()

    return rec, speaker, snippet


# ---------------------------------------------------------------------------
# Issue #230: Deleting recordings with speaker_snippets
# ---------------------------------------------------------------------------

class TestIssue230BulkDeleteCascade:
    """Verify that deleting a recording with speaker_snippets doesn't crash."""

    def test_single_delete_with_snippets(self):
        """Single DELETE /recording/<id> should succeed when snippets exist."""
        with app.app_context():
            user = _get_or_create_user()
            rec, speaker, snippet = _create_recording_with_snippets(user)
            rec_id = rec.id
            snippet_id = snippet.id

            with app.test_client() as client:
                # Login
                with client.session_transaction() as sess:
                    sess['_user_id'] = str(user.id)

                resp = client.delete(f'/recording/{rec_id}')
                assert resp.status_code == 200, f"Delete failed: {resp.get_json()}"
                data = resp.get_json()
                assert data.get('success') is True

            # Verify snippet was also deleted
            orphan = db.session.get(SpeakerSnippet, snippet_id)
            assert orphan is None, "Speaker snippet should have been deleted with recording"

            # Cleanup speaker
            db.session.delete(speaker)
            db.session.commit()

    def test_bulk_delete_with_snippets(self):
        """DELETE /api/recordings/bulk should succeed when snippets exist."""
        with app.app_context():
            user = _get_or_create_user()
            rec, speaker, snippet = _create_recording_with_snippets(user)
            rec_id = rec.id
            snippet_id = snippet.id

            with app.test_client() as client:
                with client.session_transaction() as sess:
                    sess['_user_id'] = str(user.id)

                resp = client.delete(
                    '/api/recordings/bulk',
                    json={'recording_ids': [rec_id]},
                    content_type='application/json',
                )
                assert resp.status_code == 200, f"Bulk delete failed: {resp.get_json()}"
                data = resp.get_json()
                assert data.get('success') is True
                assert rec_id in data.get('deleted_ids', [])

            # Verify snippet was also deleted
            orphan = db.session.get(SpeakerSnippet, snippet_id)
            assert orphan is None, "Speaker snippet should have been deleted with recording"

            # Cleanup speaker
            db.session.delete(speaker)
            db.session.commit()

    def test_bulk_delete_multiple_with_snippets(self):
        """Bulk deleting multiple recordings (some with snippets) should succeed."""
        with app.app_context():
            user = _get_or_create_user()
            rec1, speaker, snippet = _create_recording_with_snippets(user)
            rec2 = Recording(user_id=user.id, title="No snippets", status="COMPLETED")
            db.session.add(rec2)
            db.session.commit()

            rec1_id, rec2_id = rec1.id, rec2.id

            with app.test_client() as client:
                with client.session_transaction() as sess:
                    sess['_user_id'] = str(user.id)

                resp = client.delete(
                    '/api/recordings/bulk',
                    json={'recording_ids': [rec1_id, rec2_id]},
                    content_type='application/json',
                )
                assert resp.status_code == 200, f"Bulk delete failed: {resp.get_json()}"
                data = resp.get_json()
                assert data.get('deleted_count') == 2

            # Cleanup speaker
            db.session.delete(speaker)
            db.session.commit()

# ---------------------------------------------------------------------------
# Issue #223: File monitor stability time
# ---------------------------------------------------------------------------

class TestIssue223StabilityTime:
    """Verify AUTO_PROCESS_STABILITY_TIME env var is respected."""

    def test_default_stability_time(self):
        """Without env var, stability_time defaults to 5."""
        from src.file_monitor import FileMonitor
        monitor = FileMonitor.__new__(FileMonitor)
monitor.logger = MagicMock()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
|
||||
f.write(b'fake audio data')
|
||||
tmp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
# Remove the env var if it exists
|
||||
os.environ.pop('AUTO_PROCESS_STABILITY_TIME', None)
|
||||
with patch('time.sleep') as mock_sleep:
|
||||
monitor._is_file_stable(tmp_path)
|
||||
mock_sleep.assert_called_once_with(5)
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
def test_custom_stability_time(self):
|
||||
"""AUTO_PROCESS_STABILITY_TIME=15 should sleep for 15 seconds."""
|
||||
from src.file_monitor import FileMonitor
|
||||
monitor = FileMonitor.__new__(FileMonitor)
|
||||
monitor.logger = MagicMock()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
|
||||
f.write(b'fake audio data')
|
||||
tmp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
with patch.dict(os.environ, {'AUTO_PROCESS_STABILITY_TIME': '15'}):
|
||||
with patch('time.sleep') as mock_sleep:
|
||||
# _is_file_stable uses the default param, but the caller reads env
|
||||
# So we test the caller path via _scan_user_directory indirectly
|
||||
# or just call with explicit value
|
||||
stability_time = int(os.environ.get('AUTO_PROCESS_STABILITY_TIME', '5'))
|
||||
monitor._is_file_stable(tmp_path, stability_time)
|
||||
mock_sleep.assert_called_once_with(15)
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
def test_no_hardcoded_cap(self):
|
||||
"""Stability time should NOT be capped at 2 seconds anymore."""
|
||||
from src.file_monitor import FileMonitor
|
||||
monitor = FileMonitor.__new__(FileMonitor)
|
||||
monitor.logger = MagicMock()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
|
||||
f.write(b'fake audio data')
|
||||
tmp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
with patch('time.sleep') as mock_sleep:
|
||||
monitor._is_file_stable(tmp_path, stability_time=30)
|
||||
# Should sleep for 30, NOT min(30, 2) = 2
|
||||
mock_sleep.assert_called_once_with(30)
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main([__file__, "-v"]))
|
||||
564
tests/test_connector_architecture.py
Normal file
@@ -0,0 +1,564 @@
#!/usr/bin/env python3
"""
Test script for the transcription connector architecture.

This script tests:
1. Connector auto-detection from environment variables
2. Backwards compatibility with legacy config
3. Connector specifications and capabilities
4. Chunking logic (connector-aware)
5. Codec handling per connector
6. Request/Response data types

Run with: docker exec speakr-dev python /app/tests/test_connector_architecture.py
"""

import os
import sys
import io
import json
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

# Test results tracking
PASSED = 0
FAILED = 0
ERRORS = []


def run_test(name, func):
    """Run a test function and track results."""
    global PASSED, FAILED, ERRORS
    try:
        func()
        print(f"  ✓ {name}")
        PASSED += 1
    except AssertionError as e:
        print(f"  ✗ {name}: {e}")
        FAILED += 1
        ERRORS.append((name, str(e)))
    except Exception as e:
        print(f"  ✗ {name}: EXCEPTION - {e}")
        FAILED += 1
        ERRORS.append((name, f"Exception: {e}"))


def clear_env():
    """Clear all transcription-related environment variables."""
    keys_to_clear = [
        'TRANSCRIPTION_CONNECTOR', 'TRANSCRIPTION_API_KEY', 'TRANSCRIPTION_BASE_URL',
        'TRANSCRIPTION_MODEL', 'WHISPER_MODEL', 'USE_ASR_ENDPOINT', 'ASR_BASE_URL',
        'ASR_DIARIZE', 'ASR_RETURN_SPEAKER_EMBEDDINGS', 'ASR_TIMEOUT',
        'ASR_MIN_SPEAKERS', 'ASR_MAX_SPEAKERS', 'ENABLE_CHUNKING', 'CHUNK_LIMIT',
        'CHUNK_OVERLAP_SECONDS', 'AUDIO_UNSUPPORTED_CODECS',
    ]
    for key in keys_to_clear:
        os.environ.pop(key, None)


def reset_registry():
    """Reset the connector registry singleton."""
    from src.services.transcription import registry
    registry._registry = None
    registry.ConnectorRegistry._instance = None
    registry.ConnectorRegistry._initialized = False
    registry.ConnectorRegistry._active_connector = None
    registry.ConnectorRegistry._connector_name = ""
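
# Note: reset_registry() deliberately pokes at private singleton state: each
# test below re-initializes the registry from a fresh environment, and a
# cached instance would otherwise leak the previous test's connector choice.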


# =============================================================================
# TEST SECTION 1: Base Classes and Data Types
# =============================================================================

def test_base_classes():
    """Test base classes and data types."""
    print("\n=== Testing Base Classes ===")

    from src.services.transcription.base import (
        TranscriptionCapability, ConnectorSpecifications, TranscriptionRequest,
        TranscriptionResponse, TranscriptionSegment,
    )

    def t1():
        assert TranscriptionCapability.DIARIZATION is not None
        assert TranscriptionCapability.TIMESTAMPS is not None
        assert TranscriptionCapability.SPEAKER_COUNT_CONTROL is not None
    run_test("TranscriptionCapability enum has expected values", t1)

    def t2():
        specs = ConnectorSpecifications()
        assert specs.max_file_size_bytes is None
        assert specs.handles_chunking_internally is False
        assert specs.recommended_chunk_seconds == 600
    run_test("ConnectorSpecifications has correct defaults", t2)

    def t3():
        specs = ConnectorSpecifications(
            max_file_size_bytes=25 * 1024 * 1024,
            handles_chunking_internally=True,
            unsupported_codecs=frozenset({'opus'})
        )
        assert specs.max_file_size_bytes == 25 * 1024 * 1024
        assert 'opus' in specs.unsupported_codecs
    run_test("ConnectorSpecifications with custom values", t3)

    def t4():
        audio = io.BytesIO(b"fake audio data")
        request = TranscriptionRequest(audio_file=audio, filename="test.wav", diarize=True)
        assert request.filename == "test.wav"
        assert request.diarize is True
    run_test("TranscriptionRequest creation", t4)

    def t5():
        segments = [
            TranscriptionSegment(text="Hello", speaker="SPEAKER_00", start_time=0.0, end_time=1.0),
            TranscriptionSegment(text="World", speaker="SPEAKER_01", start_time=1.0, end_time=2.0),
        ]
        response = TranscriptionResponse(text="Hello World", segments=segments, provider="test")
        storage = response.to_storage_format()
        data = json.loads(storage)
        assert len(data) == 2
        assert data[0]['speaker'] == "SPEAKER_00"
    run_test("TranscriptionResponse to_storage_format", t5)

    def t6():
        segments = [TranscriptionSegment(text="Hello", speaker="SPEAKER_00")]
        response = TranscriptionResponse(text="Hello", segments=segments)
        assert response.has_diarization() is True

        response2 = TranscriptionResponse(text="Hello", segments=None)
        assert response2.has_diarization() is False
    run_test("TranscriptionResponse has_diarization", t6)


# =============================================================================
# TEST SECTION 2: Connector Auto-Detection
# =============================================================================

def test_auto_detection():
    """Test connector auto-detection from environment variables."""
    print("\n=== Testing Connector Auto-Detection ===")

    from src.services.transcription.registry import get_registry

    def t1():
        clear_env()
        reset_registry()
        os.environ['TRANSCRIPTION_CONNECTOR'] = 'openai_whisper'
        os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
        os.environ['ASR_BASE_URL'] = 'http://should-be-ignored:9000'
        registry = get_registry()
        registry.initialize_from_env()
        assert registry.get_active_connector_name() == 'openai_whisper'
    run_test("Explicit TRANSCRIPTION_CONNECTOR takes priority", t1)

    def t2():
        clear_env()
        reset_registry()
        os.environ['ASR_BASE_URL'] = 'http://whisperx:9000'
        registry = get_registry()
        registry.initialize_from_env()
        assert registry.get_active_connector_name() == 'asr_endpoint'
    run_test("ASR_BASE_URL auto-detects asr_endpoint", t2)

    def t3():
        clear_env()
        reset_registry()
        os.environ['USE_ASR_ENDPOINT'] = 'true'
        os.environ['ASR_BASE_URL'] = 'http://whisperx:9000'
        registry = get_registry()
        registry.initialize_from_env()
        assert registry.get_active_connector_name() == 'asr_endpoint'
    run_test("Legacy USE_ASR_ENDPOINT=true still works", t3)

    def t4():
        clear_env()
        reset_registry()
        os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
        os.environ['TRANSCRIPTION_MODEL'] = 'gpt-4o-transcribe-diarize'
        registry = get_registry()
        registry.initialize_from_env()
        assert registry.get_active_connector_name() == 'openai_transcribe'
    run_test("gpt-4o model auto-detects openai_transcribe", t4)

    def t5():
        clear_env()
        reset_registry()
        os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
        os.environ['TRANSCRIPTION_MODEL'] = 'whisper-1'
        registry = get_registry()
        registry.initialize_from_env()
        assert registry.get_active_connector_name() == 'openai_whisper'
    run_test("whisper-1 model uses openai_whisper", t5)

    def t6():
        clear_env()
        reset_registry()
        os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
        os.environ['WHISPER_MODEL'] = 'whisper-1'
        registry = get_registry()
        registry.initialize_from_env()
        assert registry.get_active_connector_name() == 'openai_whisper'
    run_test("Legacy WHISPER_MODEL still works", t6)

    def t7():
        clear_env()
        reset_registry()
        os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
        registry = get_registry()
        registry.initialize_from_env()
        assert registry.get_active_connector_name() == 'openai_whisper'
    run_test("Default falls back to openai_whisper", t7)


# =============================================================================
# TEST SECTION 3: Connector Specifications
# =============================================================================

def test_connector_specifications():
    """Test connector specifications are correctly defined."""
    print("\n=== Testing Connector Specifications ===")

    from src.services.transcription.connectors.openai_whisper import OpenAIWhisperConnector
    from src.services.transcription.connectors.openai_transcribe import OpenAITranscribeConnector
    from src.services.transcription.connectors.asr_endpoint import ASREndpointConnector
    from src.services.transcription.base import TranscriptionCapability

    def t1():
        specs = OpenAIWhisperConnector.SPECIFICATIONS
        assert specs.max_file_size_bytes == 25 * 1024 * 1024
        assert specs.handles_chunking_internally is False
    run_test("OpenAI Whisper has 25MB limit", t1)

    def t2():
        specs = OpenAIWhisperConnector.SPECIFICATIONS
        assert specs.unsupported_codecs is not None
        assert 'opus' in specs.unsupported_codecs
    run_test("OpenAI Whisper declares opus as unsupported", t2)

    def t3():
        specs = OpenAITranscribeConnector.SPECIFICATIONS
        assert specs.handles_chunking_internally is True
        assert specs.requires_chunking_param is True
    run_test("OpenAI Transcribe has internal chunking", t3)

    def t4():
        specs = ASREndpointConnector.SPECIFICATIONS
        assert specs.max_file_size_bytes is None
        assert specs.handles_chunking_internally is True
    run_test("ASR Endpoint has no limits (handles internally)", t4)

    def t5():
        assert TranscriptionCapability.DIARIZATION not in OpenAIWhisperConnector.CAPABILITIES
    run_test("OpenAI Whisper does NOT support diarization", t5)

    def t6():
        # Diarization is added dynamically based on model at instance level
        connector = OpenAITranscribeConnector({'api_key': 'test', 'model': 'gpt-4o-transcribe-diarize'})
        assert TranscriptionCapability.DIARIZATION in connector.CAPABILITIES
        assert connector.supports_diarization is True
    run_test("OpenAI Transcribe with diarize model supports diarization", t6)

    def t7():
        assert TranscriptionCapability.DIARIZATION in ASREndpointConnector.CAPABILITIES
        assert TranscriptionCapability.SPEAKER_COUNT_CONTROL in ASREndpointConnector.CAPABILITIES
    run_test("ASR Endpoint supports diarization and speaker count control", t7)

    def t8():
        assert TranscriptionCapability.SPEAKER_COUNT_CONTROL not in OpenAIWhisperConnector.CAPABILITIES
        assert TranscriptionCapability.SPEAKER_COUNT_CONTROL not in OpenAITranscribeConnector.CAPABILITIES
    run_test("OpenAI connectors do NOT support speaker count control", t8)


# =============================================================================
# TEST SECTION 4: Chunking Logic
# =============================================================================

def test_chunking_logic():
    """Test connector-aware chunking logic."""
    print("\n=== Testing Chunking Logic ===")

    from src.audio_chunking import get_effective_chunking_config
    from src.services.transcription.base import ConnectorSpecifications

    def t1():
        clear_env()
        os.environ['ENABLE_CHUNKING'] = 'true'
        os.environ['CHUNK_LIMIT'] = '20MB'
        specs = ConnectorSpecifications(handles_chunking_internally=True)
        config = get_effective_chunking_config(specs)
        assert config.enabled is False
        assert config.source == 'connector_internal'
    run_test("Connector with internal chunking disables app chunking", t1)

    def t2():
        clear_env()
        os.environ['ENABLE_CHUNKING'] = 'true'
        os.environ['CHUNK_LIMIT'] = '15MB'
        os.environ['CHUNK_OVERLAP_SECONDS'] = '5'
        specs = ConnectorSpecifications(handles_chunking_internally=False)
        config = get_effective_chunking_config(specs)
        assert config.enabled is True
        assert config.source == 'env'
        assert config.mode == 'size'
        assert config.limit_value == 15.0
    run_test("Connector without internal chunking uses ENV settings", t2)

    def t3():
        clear_env()
        os.environ['ENABLE_CHUNKING'] = 'false'
        specs = ConnectorSpecifications(handles_chunking_internally=False)
        config = get_effective_chunking_config(specs)
        assert config.enabled is False
        assert config.source == 'disabled'
    run_test("ENABLE_CHUNKING=false disables chunking", t3)

    def t4():
        clear_env()
        os.environ['ENABLE_CHUNKING'] = 'true'
        os.environ['CHUNK_LIMIT'] = '10m'
        specs = ConnectorSpecifications(handles_chunking_internally=False)
        config = get_effective_chunking_config(specs)
        assert config.enabled is True
        assert config.mode == 'duration'
        assert config.limit_value == 600.0
    run_test("Duration-based chunk limit parsing (10m = 600s)", t4)


# =============================================================================
# TEST SECTION 5: Codec Handling
# =============================================================================

def test_codec_handling():
    """Test codec handling with connector specifications."""
    print("\n=== Testing Codec Handling ===")

    from src.services.transcription.base import ConnectorSpecifications

    def reload_audio_module():
        """Properly reload audio_conversion module with fresh env vars."""
        import sys
        # Remove relevant modules from cache to force a fresh import;
        # app_config reads AUDIO_UNSUPPORTED_CODECS at import time.
        for mod_name in list(sys.modules.keys()):
            if mod_name.startswith('src.utils') or mod_name.startswith('src.config'):
                del sys.modules[mod_name]
        from src.utils import audio_conversion
        return audio_conversion

    def t1():
        clear_env()
        mod = reload_audio_module()
        codecs = mod.get_supported_codecs()
        assert 'mp3' in codecs
        assert 'flac' in codecs
    run_test("Default supported codecs include common formats", t1)

    def t2():
        clear_env()
        mod = reload_audio_module()
        specs = ConnectorSpecifications(unsupported_codecs=frozenset({'opus', 'vorbis'}))
        codecs = mod.get_supported_codecs(connector_specs=specs)
        assert 'opus' not in codecs
        assert 'vorbis' not in codecs
        assert 'mp3' in codecs
    run_test("Connector unsupported_codecs removes from defaults", t2)

    def t3():
        clear_env()
        os.environ['AUDIO_UNSUPPORTED_CODECS'] = 'aac,opus'
        mod = reload_audio_module()
        codecs = mod.get_supported_codecs()
        assert 'aac' not in codecs, f"aac should not be in {codecs}"
        assert 'opus' not in codecs, f"opus should not be in {codecs}"
    run_test("AUDIO_UNSUPPORTED_CODECS env var still works", t3)

    def t4():
        clear_env()
        os.environ['AUDIO_UNSUPPORTED_CODECS'] = 'aac'
        mod = reload_audio_module()
        specs = ConnectorSpecifications(unsupported_codecs=frozenset({'opus'}))
        codecs = mod.get_supported_codecs(connector_specs=specs)
        assert 'aac' not in codecs, f"aac should not be in {codecs}"
        assert 'opus' not in codecs, f"opus should not be in {codecs}"
        assert 'mp3' in codecs
    run_test("Both connector specs and ENV var work together", t4)


# =============================================================================
# TEST SECTION 6: Connector Capabilities
# =============================================================================

def test_connector_capabilities():
    """Test connector capabilities are exposed correctly."""
    print("\n=== Testing Connector Capabilities ===")

    from src.services.transcription.connectors.asr_endpoint import ASREndpointConnector
    from src.services.transcription.connectors.openai_transcribe import OpenAITranscribeConnector
    from src.services.transcription.base import TranscriptionCapability

    def t1():
        connector = ASREndpointConnector({'base_url': 'http://test:9000'})
        assert connector.supports_diarization is True
    run_test("ASR connector supports_diarization property", t1)

    def t2():
        connector = ASREndpointConnector({'base_url': 'http://test:9000'})
        assert connector.supports_speaker_count_control is True
    run_test("ASR connector supports_speaker_count_control property", t2)

    def t3():
        connector = OpenAITranscribeConnector({'api_key': 'test-key', 'model': 'gpt-4o-transcribe-diarize'})
        assert connector.supports_diarization is True
        assert connector.supports_speaker_count_control is False
    run_test("OpenAI Transcribe supports diarization but not speaker_count_control", t3)

    def t4():
        connector = ASREndpointConnector({'base_url': 'http://test:9000'})
        assert connector.supports(TranscriptionCapability.DIARIZATION) is True
        assert connector.supports(TranscriptionCapability.STREAMING) is False
    run_test("supports() method works correctly", t4)


# =============================================================================
# TEST SECTION 7: Registry Operations
# =============================================================================

def test_registry_operations():
    """Test registry listing and connector info."""
    print("\n=== Testing Registry Operations ===")

    from src.services.transcription.registry import get_registry

    def t1():
        clear_env()
        reset_registry()
        registry = get_registry()
        connectors = registry.list_connectors()
        names = [c['name'] for c in connectors]
        assert 'openai_whisper' in names
        assert 'openai_transcribe' in names
        assert 'asr_endpoint' in names
    run_test("Registry lists all built-in connectors", t1)

    def t2():
        clear_env()
        reset_registry()
        registry = get_registry()
        connectors = registry.list_connectors()
        asr = next(c for c in connectors if c['name'] == 'asr_endpoint')
        assert 'DIARIZATION' in asr['capabilities']
        assert 'SPEAKER_COUNT_CONTROL' in asr['capabilities']
    run_test("Connector info includes capabilities", t2)

    def t3():
        clear_env()
        reset_registry()
        os.environ['TRANSCRIPTION_API_KEY'] = 'test-key'
        os.environ['TRANSCRIPTION_MODEL'] = 'whisper-1'
        registry = get_registry()
        registry.initialize_from_env()
        assert registry.get_active_connector_name() == 'openai_whisper'

        os.environ['TRANSCRIPTION_MODEL'] = 'gpt-4o-transcribe-diarize'
        registry.reinitialize()
        assert registry.get_active_connector_name() == 'openai_transcribe'
    run_test("reinitialize() resets the active connector", t3)


# =============================================================================
# TEST SECTION 8: Edge Cases
# =============================================================================

def test_edge_cases():
    """Test edge cases and error handling."""
    print("\n=== Testing Edge Cases ===")

    from src.services.transcription.registry import get_registry
    from src.services.transcription.exceptions import ConfigurationError
    from src.services.transcription.base import TranscriptionResponse, TranscriptionSegment

    def t1():
        # Empty segments list returns the text (empty string), not "[]"
        response = TranscriptionResponse(text="", segments=[], provider="test")
        assert response.to_storage_format() == ""
        assert response.has_diarization() is False
    run_test("Empty transcription response handling", t1)

    def t2():
        segments = [TranscriptionSegment(text="Hello", speaker=None)]
        response = TranscriptionResponse(text="Hello", segments=segments)
        storage = response.to_storage_format()
        data = json.loads(storage)
        assert data[0]['speaker'] == 'Unknown Speaker'
    run_test("Transcription with unknown speaker handling", t2)

    def t3():
        clear_env()
        reset_registry()
        os.environ['TRANSCRIPTION_CONNECTOR'] = 'nonexistent_connector'
        registry = get_registry()
        try:
            registry.initialize_from_env()
            assert False, "Should have raised ConfigurationError"
        except ConfigurationError as e:
            assert 'Unknown connector' in str(e)
    run_test("Invalid connector name raises ConfigurationError", t3)

    def t4():
        from src.services.transcription.connectors.asr_endpoint import ASREndpointConnector
        try:
            ASREndpointConnector({})
            assert False, "Should have raised ConfigurationError"
        except ConfigurationError as e:
            assert 'base_url' in str(e)
    run_test("ASR connector validates base_url is required", t4)

    def t5():
        clear_env()
        reset_registry()
        os.environ['ASR_BASE_URL'] = 'http://whisperx:9000 # This is a comment'
        registry = get_registry()
        connector = registry.initialize_from_env()
        assert connector.base_url == 'http://whisperx:9000'
    run_test("ASR_BASE_URL with trailing comment is handled", t5)


# =============================================================================
# Main
# =============================================================================

def main():
    """Run all tests."""
    global PASSED, FAILED, ERRORS

    print("=" * 60)
    print("Transcription Connector Architecture Tests")
    print("=" * 60)

    test_base_classes()
    test_auto_detection()
    test_connector_specifications()
    test_chunking_logic()
    test_codec_handling()
    test_connector_capabilities()
    test_registry_operations()
    test_edge_cases()

    print("\n" + "=" * 60)
    print(f"RESULTS: {PASSED} passed, {FAILED} failed")
    print("=" * 60)

    if ERRORS:
        print("\nFailed tests:")
        for name, error in ERRORS:
            print(f"  - {name}: {error}")

    clear_env()
    return 0 if FAILED == 0 else 1


if __name__ == '__main__':
    sys.exit(main())
319
tests/test_ffprobe_codec_detection.py
Normal file
@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Test script for ffprobe codec detection functionality.

This script tests the new codec-based detection system to ensure it correctly
identifies audio codecs, video files, and lossless formats.
"""

import os
import sys
import tempfile
import subprocess
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.utils.ffprobe import (
    get_codec_info,
    is_video_file,
    is_audio_file,
    get_audio_codec,
    needs_audio_conversion,
    is_lossless_audio,
    get_duration,
    FFProbeError
)


def create_test_audio_file(codec, output_path, duration=1.0):
    """Create a test audio file with a specific codec."""
    codec_map = {
        'mp3': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'libmp3lame', '-b:a', '128k', output_path],
        'aac': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'aac', '-b:a', '128k', output_path],
        'opus': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'libopus', '-b:a', '64k', output_path],
        'flac': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'flac', output_path],
        'pcm_s16le': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'pcm_s16le', '-ar', '44100', output_path],
        'vorbis': ['ffmpeg', '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', '-acodec', 'libvorbis', '-b:a', '128k', output_path],
    }

    if codec not in codec_map:
        raise ValueError(f"Unknown codec: {codec}")

    subprocess.run(codec_map[codec], check=True, capture_output=True)


def create_test_video_file(output_path, duration=1.0):
    """Create a test video file with audio."""
    subprocess.run([
        'ffmpeg', '-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=320x240:rate=1',
        '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
        '-acodec', 'aac', '-vcodec', 'libx264', '-pix_fmt', 'yuv420p',
        output_path
    ], check=True, capture_output=True)


def test_codec_detection():
    """Test basic codec detection."""
    print("\n=== Testing Codec Detection ===\n")

    with tempfile.TemporaryDirectory() as tmpdir:
        test_files = {
            'mp3': 'test.mp3',
            'aac': 'test.m4a',
            'opus': 'test.opus',
            'flac': 'test.flac',
            'pcm_s16le': 'test.wav',
            'vorbis': 'test.ogg',
        }

        for codec, filename in test_files.items():
            filepath = os.path.join(tmpdir, filename)
            try:
                print(f"Creating test file: {filename} with codec {codec}...")
                create_test_audio_file(codec, filepath)

                print(f"  Probing {filename}...")
                codec_info = get_codec_info(filepath)

                detected_codec = codec_info['audio_codec']
                print(f"  ✓ Detected codec: {detected_codec}")
                print(f"    Has audio: {codec_info['has_audio']}")
                print(f"    Has video: {codec_info['has_video']}")
                print(f"    Format: {codec_info['format_name']}")
                print(f"    Duration: {codec_info['duration']:.2f}s" if codec_info['duration'] else "    Duration: N/A")

                if detected_codec != codec:
                    print(f"  ⚠️  Warning: Expected {codec}, got {detected_codec}")

                print()

            except Exception as e:
                print(f"  ✗ Failed to test {codec}: {e}\n")


def test_video_detection():
    """Test video file detection."""
    print("\n=== Testing Video Detection ===\n")

    with tempfile.TemporaryDirectory() as tmpdir:
        video_path = os.path.join(tmpdir, 'test_video.mp4')
        audio_path = os.path.join(tmpdir, 'test_audio.mp3')

        try:
            print("Creating test video file...")
            create_test_video_file(video_path)

            print("Creating test audio file...")
            create_test_audio_file('mp3', audio_path)

            print("\nProbing video file...")
            codec_info = get_codec_info(video_path)
            print(f"  Audio codec: {codec_info['audio_codec']}")
            print(f"  Video codec: {codec_info['video_codec']}")
            print(f"  Has audio: {codec_info['has_audio']}")
            print(f"  Has video: {codec_info['has_video']}")

            is_video = is_video_file(video_path)
            print(f"  is_video_file(): {is_video}")

            if not is_video:
                print("  ✗ Video file not detected as video!")
            else:
                print("  ✓ Video file correctly detected")

            print("\nProbing audio file...")
            codec_info = get_codec_info(audio_path)
            print(f"  Audio codec: {codec_info['audio_codec']}")
            print(f"  Video codec: {codec_info['video_codec']}")
            print(f"  Has audio: {codec_info['has_audio']}")
            print(f"  Has video: {codec_info['has_video']}")

            is_video = is_video_file(audio_path)
            print(f"  is_video_file(): {is_video}")

            if is_video:
                print("  ✗ Audio file incorrectly detected as video!")
            else:
                print("  ✓ Audio file correctly identified as audio-only")

            print()

        except Exception as e:
            print(f"✗ Failed to test video detection: {e}\n")


def test_lossless_detection():
    """Test lossless audio detection."""
    print("\n=== Testing Lossless Detection ===\n")

    with tempfile.TemporaryDirectory() as tmpdir:
        test_cases = {
            'pcm_s16le': ('test.wav', True),
            'flac': ('test.flac', True),
            'mp3': ('test.mp3', False),
            'aac': ('test.m4a', False),
            'opus': ('test.opus', False),
        }

        for codec, (filename, expected_lossless) in test_cases.items():
            filepath = os.path.join(tmpdir, filename)
            try:
                print(f"Creating {filename} with codec {codec}...")
                create_test_audio_file(codec, filepath)

                is_lossless = is_lossless_audio(filepath)
                status = "✓" if is_lossless == expected_lossless else "✗"

                print(f"  {status} {codec}: is_lossless={is_lossless} (expected {expected_lossless})")

            except Exception as e:
                print(f"  ✗ Failed to test {codec}: {e}")

        print()


def test_conversion_check():
    """Test conversion requirement detection."""
    print("\n=== Testing Conversion Check ===\n")

    with tempfile.TemporaryDirectory() as tmpdir:
        # Supported codecs for direct transcription
        supported_codecs = ['pcm_s16le', 'mp3', 'flac', 'opus', 'aac']

        test_cases = {
            'mp3': ('test.mp3', False),    # Supported, no conversion needed
            'aac': ('test.m4a', False),    # Supported, no conversion needed
            'opus': ('test.opus', False),  # Supported, no conversion needed
            'vorbis': ('test.ogg', True),  # Not in supported list, needs conversion
        }

        for codec, (filename, should_convert) in test_cases.items():
            filepath = os.path.join(tmpdir, filename)
            try:
                print(f"Creating {filename} with codec {codec}...")
                create_test_audio_file(codec, filepath)

                needs_conversion, detected_codec = needs_audio_conversion(filepath, supported_codecs)
                status = "✓" if needs_conversion == should_convert else "✗"

                print(f"  {status} {codec}: needs_conversion={needs_conversion} (expected {should_convert})")
                print(f"    Detected codec: {detected_codec}")

            except Exception as e:
                print(f"  ✗ Failed to test {codec}: {e}")

        print()


def test_misnamed_file():
    """Test detection of files with wrong extensions."""
    print("\n=== Testing Misnamed File Detection ===\n")

    with tempfile.TemporaryDirectory() as tmpdir:
        # Create an MP3 file but name it .wav
        wrong_name_path = os.path.join(tmpdir, 'actually_mp3.wav')

        try:
            print("Creating MP3 file with .wav extension...")
            create_test_audio_file('mp3', wrong_name_path)

            codec_info = get_codec_info(wrong_name_path)
            detected_codec = codec_info['audio_codec']

            print("  Filename: actually_mp3.wav")
            print(f"  Detected codec: {detected_codec}")

            if detected_codec == 'mp3':
                print("  ✓ Correctly detected MP3 codec despite .wav extension")
            else:
                print(f"  ✗ Incorrectly detected as {detected_codec}")

            # Create a FLAC file but name it .mp3
            wrong_name_path2 = os.path.join(tmpdir, 'actually_flac.mp3')
            print("\nCreating FLAC file with .mp3 extension...")
            create_test_audio_file('flac', wrong_name_path2)

            codec_info = get_codec_info(wrong_name_path2)
            detected_codec = codec_info['audio_codec']

            print("  Filename: actually_flac.mp3")
            print(f"  Detected codec: {detected_codec}")

            if detected_codec == 'flac':
                print("  ✓ Correctly detected FLAC codec despite .mp3 extension")
            else:
                print(f"  ✗ Incorrectly detected as {detected_codec}")

            print()

        except Exception as e:
            print(f"✗ Failed to test misnamed files: {e}\n")


def test_duration():
    """Test duration extraction."""
    print("\n=== Testing Duration Extraction ===\n")

    with tempfile.TemporaryDirectory() as tmpdir:
        durations = [1.0, 2.5, 5.0]

        for expected_duration in durations:
            filepath = os.path.join(tmpdir, f'test_{expected_duration}s.mp3')
            try:
                print(f"Creating {expected_duration}s audio file...")
                create_test_audio_file('mp3', filepath, duration=expected_duration)

                detected_duration = get_duration(filepath)

                # Allow 0.1s tolerance for encoding variations
                if detected_duration and abs(detected_duration - expected_duration) < 0.1:
                    print(f"  ✓ Duration: {detected_duration:.2f}s (expected {expected_duration}s)")
                else:
                    # detected_duration may be None here, so skip float formatting
                    print(f"  ✗ Duration: {detected_duration} (expected {expected_duration}s)")

            except Exception as e:
                print(f"  ✗ Failed to test duration: {e}")

        print()


def main():
    """Run all tests."""
    print("=" * 60)
    print("FFProbe Codec Detection Test Suite")
    print("=" * 60)

    # Check that ffmpeg/ffprobe are available
    try:
        subprocess.run(['ffprobe', '-version'], capture_output=True, check=True)
        subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        print("\n✗ Error: ffmpeg/ffprobe not found. Please install ffmpeg to run tests.\n")
        return 1

    try:
        test_codec_detection()
        test_video_detection()
        test_lossless_detection()
        test_conversion_check()
        test_misnamed_file()
        test_duration()

        print("=" * 60)
        print("All tests completed!")
        print("=" * 60)
        print()

        return 0

    except Exception as e:
        print(f"\n✗ Test suite failed with error: {e}\n")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == '__main__':
    sys.exit(main())
241
tests/test_hotwords.sh
Executable file
@@ -0,0 +1,241 @@
#!/bin/bash
# test_hotwords.sh - Test hotwords and initial_prompt features
#
# Usage:
#   ./tests/test_hotwords.sh <asr_url> <audio_file>
#   ./tests/test_hotwords.sh http://localhost:9000 temp/Recording\ 4.flac
#
# The audio should contain domain-specific words that Whisper tends to
# misspell (brand names, acronyms, unusual proper nouns). The script runs
# four transcriptions and compares the results.

set -euo pipefail

ASR_URL="${1:-http://localhost:9000}"
AUDIO_FILE="${2:-temp/Recording 4.flac}"
# Exported so the Python comparison heredoc below can read it from the environment
export OUTPUT_DIR="temp/hotwords_test_results"

# Colors
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
RED='\033[0;31m'
NC='\033[0m' # No Color

# Configurable test values - adjust these for your audio
HOTWORDS="Speakr,CTranslate2,PyAnnote,WhisperX"
INITIAL_PROMPT="This is a meeting about AI-powered audio transcription tools including Speakr, CTranslate2, and PyAnnote."
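
# Note: HOTWORDS goes straight into the query string below, so keep it URL-safe
# (comma-separated, no spaces); INITIAL_PROMPT is percent-encoded before use.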

echo -e "${CYAN}============================================${NC}"
echo -e "${CYAN}  Hotwords & Initial Prompt Test Suite${NC}"
echo -e "${CYAN}============================================${NC}"
echo ""
echo -e "ASR URL: ${YELLOW}${ASR_URL}${NC}"
echo -e "Audio file: ${YELLOW}${AUDIO_FILE}${NC}"
echo -e "Hotwords: ${YELLOW}${HOTWORDS}${NC}"
echo -e "Initial prompt: ${YELLOW}${INITIAL_PROMPT}${NC}"
echo ""

# Verify audio file exists
if [ ! -f "$AUDIO_FILE" ]; then
    echo -e "${RED}ERROR: Audio file not found: ${AUDIO_FILE}${NC}"
    exit 1
fi

# Verify ASR endpoint is reachable
echo -n "Checking ASR endpoint... "
if curl -sf "${ASR_URL}/" > /dev/null 2>&1 || curl -sf "${ASR_URL}/health" > /dev/null 2>&1; then
    echo -e "${GREEN}OK${NC}"
else
    echo -e "${RED}FAILED${NC}"
    echo "Cannot reach ASR endpoint at ${ASR_URL}"
    exit 1
fi

# Create output directory
mkdir -p "$OUTPUT_DIR"

# ==============================================================
# Test 1: Baseline (no hints)
# ==============================================================
echo ""
echo -e "${CYAN}--- Test 1: Baseline (no hotwords, no initial_prompt) ---${NC}"
echo -n "Transcribing... "

BASELINE_FILE="$OUTPUT_DIR/1_baseline.json"
curl -sS -X POST "${ASR_URL}/asr?output=json&task=transcribe" \
    -F "audio_file=@${AUDIO_FILE}" \
    -o "$BASELINE_FILE"

BASELINE_TEXT=$(python3 -c "
import json
d=json.load(open('$BASELINE_FILE'))
t=d.get('text','')
if isinstance(t, list):
    t=' '.join(seg.get('text','') for seg in t)
print(t[:500])
" 2>/dev/null || echo "PARSE_ERROR")
echo -e "${GREEN}Done${NC}"
echo -e "Preview: ${BASELINE_TEXT:0:200}..."
echo ""

# ==============================================================
# Test 2: With hotwords only
# ==============================================================
echo -e "${CYAN}--- Test 2: With hotwords ---${NC}"
echo -n "Transcribing... "

HOTWORDS_FILE="$OUTPUT_DIR/2_with_hotwords.json"
curl -sS -X POST "${ASR_URL}/asr?output=json&task=transcribe&hotwords=${HOTWORDS}" \
    -F "audio_file=@${AUDIO_FILE}" \
    -o "$HOTWORDS_FILE"

HOTWORDS_TEXT=$(python3 -c "
import json
d=json.load(open('$HOTWORDS_FILE'))
t=d.get('text','')
if isinstance(t, list):
    t=' '.join(seg.get('text','') for seg in t)
print(t[:500])
" 2>/dev/null || echo "PARSE_ERROR")
echo -e "${GREEN}Done${NC}"
echo -e "Preview: ${HOTWORDS_TEXT:0:200}..."
echo ""

# ==============================================================
# Test 3: With hotwords + initial_prompt
# ==============================================================
echo -e "${CYAN}--- Test 3: With hotwords + initial_prompt ---${NC}"
echo -n "Transcribing... "

BOTH_FILE="$OUTPUT_DIR/3_with_both.json"
ENCODED_PROMPT=$(python3 -c "import urllib.parse; print(urllib.parse.quote('$INITIAL_PROMPT'))")
curl -sS -X POST "${ASR_URL}/asr?output=json&task=transcribe&hotwords=${HOTWORDS}&initial_prompt=${ENCODED_PROMPT}" \
    -F "audio_file=@${AUDIO_FILE}" \
    -o "$BOTH_FILE"

BOTH_TEXT=$(python3 -c "
import json
d=json.load(open('$BOTH_FILE'))
t=d.get('text','')
if isinstance(t, list):
    t=' '.join(seg.get('text','') for seg in t)
print(t[:500])
" 2>/dev/null || echo "PARSE_ERROR")
echo -e "${GREEN}Done${NC}"
echo -e "Preview: ${BOTH_TEXT:0:200}..."
echo ""

# ==============================================================
# Test 4: With initial_prompt only
# ==============================================================
echo -e "${CYAN}--- Test 4: With initial_prompt only ---${NC}"
echo -n "Transcribing... "

PROMPT_FILE="$OUTPUT_DIR/4_with_initial_prompt.json"
curl -sS -X POST "${ASR_URL}/asr?output=json&task=transcribe&initial_prompt=${ENCODED_PROMPT}" \
    -F "audio_file=@${AUDIO_FILE}" \
    -o "$PROMPT_FILE"

PROMPT_TEXT=$(python3 -c "
import json
d=json.load(open('$PROMPT_FILE'))
t=d.get('text','')
if isinstance(t, list):
    t=' '.join(seg.get('text','') for seg in t)
print(t[:500])
" 2>/dev/null || echo "PARSE_ERROR")
echo -e "${GREEN}Done${NC}"
echo -e "Preview: ${PROMPT_TEXT:0:200}..."
echo ""

# ==============================================================
# Comparison
# ==============================================================
echo -e "${CYAN}============================================${NC}"
echo -e "${CYAN}  Comparison Results${NC}"
echo -e "${CYAN}============================================${NC}"
echo ""

# Check if hotwords appear in outputs
python3 << 'PYEOF'
import json
import os

def extract_text(data):
    """Extract full text from ASR response, handling both string and segment list formats."""
    text = data.get("text", "")
    if isinstance(text, list):
        return " ".join(seg.get("text", "") for seg in text)
    return text

hotwords = ["Speakr", "CTranslate2", "PyAnnote", "WhisperX"]
output_dir = os.environ.get("OUTPUT_DIR", "temp/hotwords_test_results")
test_files = {
    "1. Baseline": f"{output_dir}/1_baseline.json",
    "2. Hotwords only": f"{output_dir}/2_with_hotwords.json",
    "3. Hotwords + prompt": f"{output_dir}/3_with_both.json",
    "4. Initial prompt only": f"{output_dir}/4_with_initial_prompt.json",
}

print(f"{'Test':<25} | {'Hotword Matches':<20} | {'Words Found'}")
print("-" * 75)

for label, filepath in test_files.items():
    try:
        with open(filepath) as f:
            data = json.load(f)
        text = extract_text(data)
        found = []
        for hw in hotwords:
            if hw.lower() in text.lower():
                found.append(hw)
        match_str = f"{len(found)}/{len(hotwords)}"
        found_str = ", ".join(found) if found else "(none)"
        print(f"{label:<25} | {match_str:<20} | {found_str}")
    except Exception as e:
        print(f"{label:<25} | ERROR: {e}")

print()
print("Full outputs saved to: " + output_dir)
PYEOF

echo ""
echo -e "${CYAN}============================================${NC}"
echo -e "${CYAN}  Precedence Test via Speakr API${NC}"
echo -e "${CYAN}============================================${NC}"
echo ""
echo -e "${YELLOW}To test the full precedence chain (user → folder → tag → upload form),${NC}"
echo -e "${YELLOW}use the Speakr web API with authentication:${NC}"
echo ""
echo -e "1. Set user-level defaults in Account Settings → Prompt Options"
echo -e "2. Create a tag with different hotwords/initial_prompt"
echo -e "3. Create a folder with different hotwords/initial_prompt"
echo -e "4. Upload via API and check server logs for resolved values:"
echo ""
cat << 'EXAMPLE'
# Upload with user defaults only (no tag, no folder)
curl -X POST "https://your-speakr/upload" \
  -H "Authorization: Bearer YOUR_TOKEN" \
  -F "file=@test.flac"
# → Should use user defaults

# Upload with a tag that has hotwords set
curl -X POST "https://your-speakr/upload" \
  -H "Authorization: Bearer YOUR_TOKEN" \
  -F "file=@test.flac" \
  -F "tags=TAG_ID_WITH_HOTWORDS"
# → Should use tag defaults (overrides user)

# Upload with explicit form values (highest priority)
curl -X POST "https://your-speakr/upload" \
  -H "Authorization: Bearer YOUR_TOKEN" \
  -F "file=@test.flac" \
  -F "tags=TAG_ID_WITH_HOTWORDS" \
  -F "hotwords=FormOverride1,FormOverride2" \
  -F "initial_prompt=Form level prompt"
# → Should use form values (overrides tag and user)
EXAMPLE

echo ""
echo -e "${GREEN}Test complete!${NC}"
187
tests/test_inquire_mode.py
Normal file
@@ -0,0 +1,187 @@
#!/usr/bin/env python3
"""
Test script for Inquire Mode functionality
"""
import os
import sys

# Add the parent directory to the path to import app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.app import app, db, User, Recording, TranscriptChunk, InquireSession, Tag


def test_database_models():
    """Test that the new database models work correctly."""
    with app.app_context():
        print("🔍 Testing Inquire Mode Database Models...")

        # Test that tables exist
        from sqlalchemy import inspect
        inspector = inspect(db.engine)
        tables = inspector.get_table_names()

        required_tables = ['transcript_chunk', 'inquire_session']
        for table in required_tables:
            if table in tables:
                print(f"✅ Table '{table}' exists")
            else:
                print(f"❌ Table '{table}' missing")
                return False

        # Test creating sample data
        try:
            # Create a test user (or get existing one)
            user = User.query.first()
            if not user:
                print("❌ No users found. Please create a user first.")
                return False

            print(f"📝 Using test user: {user.username}")

            # Create a test recording if none exist
            recording = Recording.query.filter_by(user_id=user.id).first()
            if not recording:
                print("❌ No recordings found. Please create a recording first.")
                return False

            print(f"🎵 Using test recording: {recording.title}")

            # Test TranscriptChunk creation
            chunk = TranscriptChunk(
                recording_id=recording.id,
                user_id=user.id,
                chunk_index=0,
                content="This is a test transcription chunk.",
                start_time=0.0,
                end_time=5.0,
                speaker_name="Test Speaker"
            )

            db.session.add(chunk)

            # Test InquireSession creation
            session = InquireSession(
                user_id=user.id,
                session_name="Test Session",
                filter_tags='[]',
                filter_speakers='["Test Speaker"]'
            )

            db.session.add(session)
            db.session.commit()

            print("✅ Successfully created test TranscriptChunk and InquireSession")

            # Clean up test data
            db.session.delete(chunk)
            db.session.delete(session)
            db.session.commit()

            print("✅ Test data cleaned up")

        except Exception as e:
            print(f"❌ Error testing models: {e}")
            return False

        return True


def test_chunking_functions():
    """Test the chunking and embedding functions."""
    with app.app_context():
        print("🔧 Testing Chunking Functions...")

        try:
            from src.app import chunk_transcription, generate_embeddings, serialize_embedding, deserialize_embedding

            # Test chunking
            test_text = "This is a test sentence. This is another sentence for testing. And here's a third sentence to make sure chunking works properly with longer text that should be split into multiple chunks."
            chunks = chunk_transcription(test_text, max_chunk_length=100, overlap=20)

            if len(chunks) > 1:
                print(f"✅ Chunking works: {len(chunks)} chunks created")
            else:
                print("✅ Text too short for chunking (expected behavior)")

            # Test embeddings (will only work if sentence-transformers is installed)
            try:
                embeddings = generate_embeddings(["test sentence", "another test"])
                if len(embeddings) == 2:
                    print("✅ Embedding generation works")

                    # Test serialization round trip
                    if embeddings[0] is not None:
                        serialized = serialize_embedding(embeddings[0])
                        deserialized = deserialize_embedding(serialized)
                        if deserialized is not None and len(deserialized) > 0:
                            print("✅ Embedding serialization/deserialization works")
                        else:
                            print("❌ Embedding serialization/deserialization failed")
                else:
                    print("❌ Embedding generation returned wrong number of embeddings")

            except Exception as e:
                print(f"⚠️ Embedding test skipped (sentence-transformers may not be installed): {e}")

        except Exception as e:
            print(f"❌ Error testing chunking functions: {e}")
            return False

        return True


def test_api_imports():
    """Test that all API endpoints can be imported."""
    print("🔌 Testing API Endpoint Imports...")

    try:
        from src.app import (
            get_inquire_sessions, create_inquire_session, inquire_search,
            inquire_chat, get_available_filters, process_recording_chunks_endpoint
        )
        print("✅ All inquire mode API endpoints imported successfully")
        return True
    except ImportError as e:
        print(f"❌ Failed to import API endpoints: {e}")
        return False


def main():
    """Run all tests."""
    print("🚀 Starting Inquire Mode Tests...\n")

    tests = [
        ("Database Models", test_database_models),
        ("Chunking Functions", test_chunking_functions),
        ("API Imports", test_api_imports)
    ]

    results = []
    for test_name, test_func in tests:
        print(f"\n--- {test_name} ---")
        try:
            success = test_func()
            results.append((test_name, success))
        except Exception as e:
            print(f"❌ {test_name} failed with exception: {e}")
            results.append((test_name, False))

    print("\n" + "=" * 50)
    print("📊 Test Results Summary:")
    print("=" * 50)

    all_passed = True
    for test_name, success in results:
        status = "✅ PASS" if success else "❌ FAIL"
        print(f"{status} - {test_name}")
        if not success:
            all_passed = False

    print("\n" + "=" * 50)
    if all_passed:
        print("🎉 All tests passed! Inquire Mode is ready to use.")
    else:
        print("⚠️ Some tests failed. Please check the output above.")

    return all_passed


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
281
tests/test_job_queue_race_condition.py
Normal file
@@ -0,0 +1,281 @@
#!/usr/bin/env python3
"""
Test script for the job queue race condition fix.

This script verifies that the atomic job claiming mechanism prevents
multiple workers from claiming the same job simultaneously.

The fix uses an atomic UPDATE with a WHERE clause to ensure only one
worker can claim a job, even with multiple processes/threads.
"""
|
||||

import os
import sys
import threading
import time
from pathlib import Path
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))


def test_atomic_job_claiming():
    """
    Test that only one worker can claim a job even with concurrent attempts.

    This simulates the race condition where multiple workers try to claim
    the same job simultaneously.
    """
    print("\n=== Testing Atomic Job Claiming ===\n")

    # Import Flask app and models
    from src.app import app
    from src.database import db
    from src.models import ProcessingJob, User, Recording
    from sqlalchemy import update

    with app.app_context():
        # Use the first existing user for testing, or create a minimal test user
        test_user = User.query.first()
        if not test_user:
            test_user = User(
                username='test_race_condition_user',
                email='test_race@example.com',
                password='not_used'  # Password not needed for this test
            )
            db.session.add(test_user)
            db.session.commit()

        # Create a test recording
        test_recording = Recording(
            user_id=test_user.id,
            title='Test Race Condition Recording',
            audio_path='/tmp/test_audio.mp3',
            status='QUEUED'
        )
        db.session.add(test_recording)
        db.session.commit()

        # Create a test job in 'queued' status
        test_job = ProcessingJob(
            recording_id=test_recording.id,
            user_id=test_user.id,
            job_type='transcribe',
            status='queued'
        )
        db.session.add(test_job)
        db.session.commit()

        job_id = test_job.id
        print(f"Created test job {job_id} with status 'queued'")

        # Track which threads successfully claimed the job
        successful_claims = []
        claim_lock = threading.Lock()

        def attempt_claim(worker_id):
            """Simulate a worker attempting to claim the job."""
            with app.app_context():
                try:
                    # This is the atomic claim logic from the fix
                    claim_time = datetime.utcnow()
                    result = db.session.execute(
                        update(ProcessingJob)
                        .where(
                            ProcessingJob.id == job_id,
                            ProcessingJob.status == 'queued'
                        )
                        .values(status='processing', started_at=claim_time)
                    )

                    if result.rowcount == 1:
                        db.session.commit()
                        with claim_lock:
                            successful_claims.append(worker_id)
                        return f"Worker {worker_id}: Successfully claimed job"
                    else:
                        db.session.rollback()
                        return f"Worker {worker_id}: Job already claimed (rowcount=0)"

                except Exception as e:
                    db.session.rollback()
                    return f"Worker {worker_id}: Error - {e}"

        # Spawn multiple threads to claim simultaneously
        num_workers = 10
        print(f"\nSpawning {num_workers} workers to claim job {job_id} simultaneously...")

        # Use a barrier to ensure all threads start at the same time
        barrier = threading.Barrier(num_workers)

        def worker_with_barrier(worker_id):
            barrier.wait()  # Wait for all threads to be ready
            return attempt_claim(worker_id)

        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {executor.submit(worker_with_barrier, i): i for i in range(num_workers)}

            for future in as_completed(futures):
                result = future.result()
                print(f"  {result}")

        # Verify results
        print("\n=== Results ===")
        print(f"Total workers: {num_workers}")
        print(f"Successful claims: {len(successful_claims)}")
        print(f"Workers that claimed: {successful_claims}")

        # Check final job status
        db.session.expire_all()
        final_job = db.session.get(ProcessingJob, job_id)
        print(f"Final job status: {final_job.status}")

        # Assert before cleanup: once the row is deleted and committed, its
        # attributes can no longer be read reliably
        assert len(successful_claims) == 1, f"Expected 1 successful claim, got {len(successful_claims)}"
        assert final_job.status == 'processing', f"Expected status 'processing', got {final_job.status}"

        # Cleanup
        db.session.delete(final_job)
        db.session.delete(test_recording)
        db.session.commit()

        print("\n[PASS] Only one worker successfully claimed the job!")
        return True


def test_multiple_jobs_fair_distribution():
    """
    Test that multiple queued jobs are each claimed exactly once, even when
    more claim attempts are made than jobs exist.
    """
    print("\n=== Testing Multiple Jobs Distribution ===\n")

    from src.app import app
    from src.database import db
    from src.models import ProcessingJob, User, Recording
    from sqlalchemy import update

    with app.app_context():
        # Use the first existing user for testing
        test_user = User.query.first()
        if not test_user:
            test_user = User(
                username='test_distribution_user',
                email='test_dist@example.com',
                password='not_used'
            )
            db.session.add(test_user)
            db.session.commit()

        # Create multiple test jobs
        num_jobs = 5
        job_ids = []
        recording_ids = []

        for i in range(num_jobs):
            recording = Recording(
                user_id=test_user.id,
                title=f'Test Distribution Recording {i}',
                audio_path=f'/tmp/test_audio_{i}.mp3',
                status='QUEUED'
            )
            db.session.add(recording)
            db.session.commit()
            recording_ids.append(recording.id)

            job = ProcessingJob(
                recording_id=recording.id,
                user_id=test_user.id,
                job_type='transcribe',
                status='queued'
            )
            db.session.add(job)
            db.session.commit()
            job_ids.append(job.id)

        print(f"Created {num_jobs} test jobs: {job_ids}")

        # Have workers claim jobs
        claimed_jobs = []

        def claim_any_job(worker_id):
            with app.app_context():
                # Find a queued job
                candidate = ProcessingJob.query.filter(
                    ProcessingJob.status == 'queued',
                    ProcessingJob.job_type == 'transcribe'
                ).first()

                if not candidate:
                    return None

                # Atomic claim
                result = db.session.execute(
                    update(ProcessingJob)
                    .where(
                        ProcessingJob.id == candidate.id,
                        ProcessingJob.status == 'queued'
                    )
                    .values(status='processing', started_at=datetime.utcnow())
                )

                if result.rowcount == 1:
                    db.session.commit()
                    return candidate.id
                else:
                    db.session.rollback()
                    return None

        # Each "worker" claims one job
        for i in range(num_jobs + 2):  # Extra attempts to ensure no double claims
            job_id = claim_any_job(i)
            if job_id:
                claimed_jobs.append(job_id)
                print(f"  Worker {i} claimed job {job_id}")
            else:
                print(f"  Worker {i} found no available jobs")

        print(f"\nClaimed jobs: {claimed_jobs}")
        print(f"Unique jobs claimed: {len(set(claimed_jobs))}")

        # Verify no duplicates
        assert len(claimed_jobs) == len(set(claimed_jobs)), "Duplicate job claims detected!"
        assert len(claimed_jobs) == num_jobs, f"Expected {num_jobs} claims, got {len(claimed_jobs)}"

        # Cleanup
        for job_id in job_ids:
            job = db.session.get(ProcessingJob, job_id)
            if job:
                db.session.delete(job)
        for rec_id in recording_ids:
            rec = db.session.get(Recording, rec_id)
            if rec:
                db.session.delete(rec)
        db.session.commit()

        print("\n[PASS] All jobs claimed exactly once!")
        return True


if __name__ == '__main__':
    print("=" * 60)
    print("Job Queue Race Condition Tests")
    print("=" * 60)

    try:
        test_atomic_job_claiming()
        test_multiple_jobs_fair_distribution()

        print("\n" + "=" * 60)
        print("All tests passed!")
        print("=" * 60)

    except AssertionError as e:
        print(f"\n[FAIL] Test failed: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"\n[ERROR] Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
58
tests/test_json_fix.py
Normal file
@@ -0,0 +1,58 @@
import json
import sys
import os

# Add the parent directory to the path to import app
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.app import auto_close_json, safe_json_loads
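
# Behavior pinned down below (illustrative shorthand of the test cases,
# assuming auto_close_json appends whatever closers the input is missing):
#   auto_close_json('{"a": "b')          -> '{"a": "b"}'
#   safe_json_loads('this is not json')  -> None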


def run_tests():
    """Runs a series of tests for the JSON fixing functions."""

    test_cases_auto_close = {
        "Unterminated string": ('{"title": "Test", "summary": "This is a test', '{"title": "Test", "summary": "This is a test"}'),
        "Already closed object": ('{"title": "Test", "summary": "This is a test"}', '{"title": "Test", "summary": "This is a test"}'),
        "Missing closing bracket": ('[{"item": 1}, {"item": 2}', '[{"item": 1}, {"item": 2}]'),
        "Nested unterminated": ('{"data": {"items": [1, 2', '{"data": {"items": [1, 2]}}'),
        "String at the end": ('{"key": "value', '{"key": "value"}'),
        "Empty string": ('', ''),
        "Already valid": ('{"a": 1}', '{"a": 1}'),
        "Complex nested object": ('{"a": {"b": {"c": [1, 2, {"d": "e', '{"a": {"b": {"c": [1, 2, {"d": "e"}]}}}')
    }

    print("--- Testing auto_close_json ---")
    for name, (input_str, expected_str) in test_cases_auto_close.items():
        result = auto_close_json(input_str)
        print(f"Test: {name}")
        print(f"  Input: '{input_str}'")
        print(f"  Output: '{result}'")
        print(f"  Expected: '{expected_str}'")
        assert result == expected_str, f"Failed: {name}"
        print("  Result: PASSED\n")

    test_cases_safe_loads = {
        "Unterminated string": '{"title": "Test", "summary": "This is a test',
        "Markdown with unterminated JSON": '```json\n{"title": "Test", "summary": "This is a test\n```',
        "Already closed object": '{"title": "Test", "summary": "This is a test"}',
        "Valid JSON": '{"title": "Complete", "summary": "This is a complete JSON."}',
        "JSON with escaped quotes": '{"title": "Escaped", "summary": "This is a \\"test\\" with quotes."}',
        "Invalid JSON": 'this is not json',
    }

    print("\n--- Testing safe_json_loads ---")
    for name, input_str in test_cases_safe_loads.items():
        result = safe_json_loads(input_str)
        print(f"Test: {name}")
        print(f"  Input: '{input_str}'")
        print(f"  Output: {result}")
        if name == "Invalid JSON":
            assert result is None, f"Failed: {name}"
            print("  Result: PASSED (Correctly returned None)\n")
        else:
            assert isinstance(result, dict), f"Failed: {name}"
            print("  Result: PASSED\n")


if __name__ == "__main__":
    run_tests()
    print("All tests completed successfully!")
291
tests/test_json_preprocessing.py
Normal file
@@ -0,0 +1,291 @@
#!/usr/bin/env python3
"""
Test suite for JSON preprocessing functionality in Speakr app.
Tests the safe_json_loads function with various malformed JSON scenarios.
"""

import sys
import os
import json
import unittest
from unittest.mock import Mock

# Add the app directory to the path so we can import from app.py
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Minimal stand-in app with a mocked logger. Note that the functions imported
# below log through src.app's own logger, so this mock does not intercept
# their logging; it simply gives this module an `app` name of its own.
class MockApp:
    def __init__(self):
        self.logger = Mock()

app = MockApp()

# Import the functions we want to test
from src.app import safe_json_loads, preprocess_json_escapes, extract_json_object


class TestJSONPreprocessing(unittest.TestCase):
    """Test cases for JSON preprocessing functionality."""

    def test_valid_json(self):
        """Test that valid JSON is parsed correctly."""
        valid_json = '{"title": "Test Meeting", "summary": "This is a test summary"}'
        result = safe_json_loads(valid_json)
        expected = {"title": "Test Meeting", "summary": "This is a test summary"}
        self.assertEqual(result, expected)

    def test_json_with_markdown_code_blocks(self):
        """Test JSON wrapped in markdown code blocks."""
        markdown_json = '''```json
{
    "title": "Meeting Notes",
    "summary": "Key points discussed"
}
```'''
        result = safe_json_loads(markdown_json)
        expected = {"title": "Meeting Notes", "summary": "Key points discussed"}
        self.assertEqual(result, expected)

    def test_json_with_unescaped_quotes(self):
        """Test JSON with unescaped quotes in string values."""
        malformed_json = '{"title": "John said "Hello world" to everyone", "summary": "Meeting summary"}'
        result = safe_json_loads(malformed_json)
        expected = {"title": 'John said "Hello world" to everyone', "summary": "Meeting summary"}
        self.assertEqual(result, expected)

    def test_json_with_mixed_quotes(self):
        """Test JSON with mixed quote scenarios."""
        malformed_json = '{"title": "Alice\'s "big idea" presentation", "summary": "She said "this will change everything""}'
        result = safe_json_loads(malformed_json)
        self.assertIsInstance(result, dict)
        self.assertIn("title", result)
        self.assertIn("summary", result)

    def test_json_with_newlines_and_special_chars(self):
        """Test JSON with newlines and special characters."""
        malformed_json = '''{"title": "Complex Meeting", "summary": "Discussion about:\n- Point 1\n- Point 2 with "quotes"\n- Point 3"}'''
        result = safe_json_loads(malformed_json)
        self.assertIsInstance(result, dict)
        self.assertIn("title", result)
        self.assertIn("summary", result)

    def test_empty_or_invalid_input(self):
        """Test handling of empty or invalid input."""
        # Empty string
        result = safe_json_loads("", {"default": "value"})
        self.assertEqual(result, {"default": "value"})

        # None input
        result = safe_json_loads(None, {"default": "value"})
        self.assertEqual(result, {"default": "value"})

        # Non-string input
        result = safe_json_loads(123, {"default": "value"})
        self.assertEqual(result, {"default": "value"})

    def test_completely_malformed_json(self):
        """Test completely malformed JSON that can't be fixed."""
        malformed_json = '{"title": "Test", "summary": unclosed string and missing quotes}'
        result = safe_json_loads(malformed_json, {"error": "fallback"})
        self.assertEqual(result, {"error": "fallback"})

    def test_json_with_nested_quotes(self):
        """Test JSON with deeply nested quote scenarios."""
        malformed_json = '{"title": "Meeting about "Project Alpha" and "Project Beta"", "summary": "Both projects involve "cutting-edge" technology"}'
        result = safe_json_loads(malformed_json)
        self.assertIsInstance(result, dict)
        # Should have successfully parsed something
        self.assertTrue(len(result) > 0)

    def test_json_array_format(self):
        """Test JSON array format (for transcription data)."""
        json_array = '[{"speaker": "John", "sentence": "Hello everyone"}, {"speaker": "Jane", "sentence": "Good morning"}]'
        result = safe_json_loads(json_array)
        expected = [{"speaker": "John", "sentence": "Hello everyone"}, {"speaker": "Jane", "sentence": "Good morning"}]
        self.assertEqual(result, expected)

    def test_preprocess_json_escapes_function(self):
        """Test the preprocess_json_escapes function directly."""
        input_json = '{"title": "John said "Hello" to Mary", "summary": "Simple test"}'
        processed = preprocess_json_escapes(input_json)
        # Should be valid JSON after preprocessing
        result = json.loads(processed)
        self.assertIsInstance(result, dict)
        self.assertIn("title", result)
        self.assertIn("summary", result)

    def test_extract_json_object_function(self):
        """Test the extract_json_object function directly."""
        # Test with extra text around JSON object
        text_with_json = 'Here is some text {"title": "Test", "summary": "Content"} and more text'
        extracted = extract_json_object(text_with_json)
        result = json.loads(extracted)
        expected = {"title": "Test", "summary": "Content"}
        self.assertEqual(result, expected)

        # Test with JSON array
        text_with_array = 'Some text [{"item": "one"}, {"item": "two"}] more text'
        extracted = extract_json_object(text_with_array)
        result = json.loads(extracted)
        expected = [{"item": "one"}, {"item": "two"}]
        self.assertEqual(result, expected)

    def test_real_world_llm_response_scenarios(self):
        """Test real-world scenarios that might come from LLM responses."""

        # Scenario 1: LLM response with explanation text
        llm_response1 = '''Here's the JSON response you requested:

```json
{
    "title": "Q3 Planning Meeting",
    "summary": "We discussed the "new initiative" and John's "breakthrough idea" for next quarter."
}
```

This should help with your transcription needs.'''

        result1 = safe_json_loads(llm_response1)
        self.assertIsInstance(result1, dict)
        self.assertIn("title", result1)
        self.assertIn("summary", result1)

        # Scenario 2: LLM response with unescaped quotes and no code blocks
        llm_response2 = '{"title": "Team Standup", "summary": "Alice mentioned "the deadline is tight" and Bob said "we need more resources""}'

        result2 = safe_json_loads(llm_response2)
        self.assertIsInstance(result2, dict)
        self.assertIn("title", result2)
        self.assertIn("summary", result2)

        # Scenario 3: LLM response with speaker identification
        llm_response3 = '''{"SPEAKER_00": "John Smith", "SPEAKER_01": "Jane "The Expert" Doe", "SPEAKER_02": "Bob"}'''

        result3 = safe_json_loads(llm_response3)
        self.assertIsInstance(result3, dict)
        self.assertTrue(len(result3) >= 2)  # Should have parsed at least some speakers

    def test_fallback_strategies(self):
        """Test that different parsing strategies work as fallbacks."""

        # Test ast.literal_eval fallback for simple cases
        simple_dict = "{'title': 'Simple', 'summary': 'Test'}"
        result = safe_json_loads(simple_dict)
        expected = {"title": "Simple", "summary": "Test"}
        self.assertEqual(result, expected)

        # Test regex extraction fallback
        messy_response = 'Some text before {"title": "Extracted", "summary": "From regex"} some text after'
        result = safe_json_loads(messy_response)
        expected = {"title": "Extracted", "summary": "From regex"}
        self.assertEqual(result, expected)

    def test_performance_with_large_content(self):
        """Test performance with larger JSON content."""
        large_summary = "This is a very long summary. " * 100  # Create a long string
        large_json = f'{{"title": "Large Content Test", "summary": "{large_summary}"}}'

        result = safe_json_loads(large_json)
        self.assertIsInstance(result, dict)
        self.assertIn("title", result)
        self.assertIn("summary", result)
        self.assertEqual(result["title"], "Large Content Test")


def run_comprehensive_test():
    """Run a comprehensive test with various malformed JSON examples."""
    print("🧪 Running comprehensive JSON preprocessing tests...\n")

    test_cases = [
        {
            "name": "Valid JSON",
            "input": '{"title": "Test", "summary": "Valid JSON"}',
            "should_succeed": True
        },
        {
            "name": "Unescaped quotes in title",
            "input": '{"title": "Meeting about "Project X"", "summary": "Discussion summary"}',
            "should_succeed": True
        },
        {
            "name": "Multiple unescaped quotes",
            "input": '{"title": "John said "Hello" and Mary replied "Hi there"", "summary": "Conversation log"}',
            "should_succeed": True
        },
        {
            "name": "Markdown code block",
            "input": '```json\n{"title": "Wrapped", "summary": "In code block"}\n```',
            "should_succeed": True
        },
        {
            "name": "Mixed quotes and apostrophes",
            "input": '{"title": "Alice\'s "big idea" presentation", "summary": "She said it\'s "revolutionary""}',
            "should_succeed": True
        },
        {
            "name": "JSON with newlines",
            "input": '{"title": "Multi-line", "summary": "Line 1\\nLine 2 with \\"quotes\\"\\nLine 3"}',
            "should_succeed": True
        },
        {
            "name": "Completely malformed",
            "input": '{"title": "Test", "summary": this is not valid json at all}',
            "should_succeed": False
        },
        {
            "name": "Empty string",
            "input": "",
            "should_succeed": False
        }
    ]

    passed = 0
    failed = 0

    for i, test_case in enumerate(test_cases, 1):
        print(f"Test {i}: {test_case['name']}")
        print(f"Input: {test_case['input'][:100]}{'...' if len(test_case['input']) > 100 else ''}")

        try:
            result = safe_json_loads(test_case['input'], {"error": "fallback"})

            if test_case['should_succeed']:
                if result != {"error": "fallback"} and isinstance(result, dict):
                    print("✅ PASSED - Successfully parsed JSON")
                    passed += 1
                else:
                    print("❌ FAILED - Expected successful parsing but got fallback")
                    failed += 1
            else:
                if result == {"error": "fallback"}:
                    print("✅ PASSED - Correctly returned fallback for malformed JSON")
                    passed += 1
                else:
                    print("❌ FAILED - Expected fallback but got parsed result")
                    failed += 1

        except Exception as e:
            print(f"❌ FAILED - Exception occurred: {e}")
            failed += 1

        print("-" * 50)

    print(f"\n📊 Test Results: {passed} passed, {failed} failed")
    return failed == 0


if __name__ == "__main__":
    print("🚀 Starting JSON Preprocessing Tests for Speakr App\n")

    # Run the comprehensive manual test
    manual_success = run_comprehensive_test()

    print("\n" + "="*60)
    print("🔬 Running Unit Tests")
    print("="*60)

    # Run the unit tests
    unittest.main(argv=[''], exit=False, verbosity=2)

    if manual_success:
        print("\n🎉 All tests completed! JSON preprocessing should handle LLM response issues gracefully.")
    else:
        print("\n⚠️ Some tests failed. Please review the implementation.")
340
tests/test_json_standalone.py
Normal file
@@ -0,0 +1,340 @@
#!/usr/bin/env python3
"""
Standalone test for JSON preprocessing functionality.
Tests the safe_json_loads function with various malformed JSON scenarios.
"""

import json
import re
import ast


# Mock logger for testing
class MockLogger:
    def warning(self, msg): print(f"WARNING: {msg}")
    def info(self, msg): print(f"INFO: {msg}")
    def debug(self, msg): print(f"DEBUG: {msg}")
    def error(self, msg): print(f"ERROR: {msg}")


# Create mock app with logger
class MockApp:
    logger = MockLogger()


app = MockApp()


def safe_json_loads(json_string, fallback_value=None):
    """
    Safely parse JSON with preprocessing to handle common LLM JSON formatting issues.

    Args:
        json_string (str): The JSON string to parse
        fallback_value: Value to return if parsing fails (default: None)

    Returns:
        Parsed JSON object or fallback_value if parsing fails
    """
    if not json_string or not isinstance(json_string, str):
        app.logger.warning(f"Invalid JSON input: {type(json_string)} - {json_string}")
        return fallback_value

    # Step 1: Clean the input string
    cleaned_json = json_string.strip()

    # Step 2: Extract JSON from markdown code blocks if present
    json_match = re.search(r'```(?:json)?\s*(.*?)\s*```', cleaned_json, re.DOTALL)
    if json_match:
        cleaned_json = json_match.group(1).strip()

    # Step 3: Try multiple parsing strategies
    parsing_strategies = [
        # Strategy 1: Direct parsing (for well-formed JSON)
        lambda x: json.loads(x),

        # Strategy 2: Fix common escape issues
        lambda x: json.loads(preprocess_json_escapes(x)),

        # Strategy 3: Use ast.literal_eval as fallback for simple cases
        lambda x: ast.literal_eval(x) if x.startswith(('{', '[')) else None,

        # Strategy 4: Extract JSON object/array using regex
        lambda x: json.loads(extract_json_object(x)),
    ]

    for i, strategy in enumerate(parsing_strategies):
        try:
            result = strategy(cleaned_json)
            if result is not None:
                if i > 0:  # Log if we had to use a fallback strategy
                    app.logger.info(f"JSON parsed successfully using strategy {i+1}")
                return result
        except (json.JSONDecodeError, ValueError, SyntaxError) as e:
            if i == 0:  # Only log the first failure to avoid spam
                app.logger.debug(f"JSON parsing strategy {i+1} failed: {e}")
            continue

    # All strategies failed
    app.logger.error(f"All JSON parsing strategies failed for: {cleaned_json[:200]}...")
    return fallback_value
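
# Illustrative round trips through the helper above (derived from the
# strategies it tries in order):
#   safe_json_loads('```json\n{"a": 1}\n```')       -> {"a": 1}   (markdown fence stripped)
#   safe_json_loads("{'a': 1}")                     -> {'a': 1}   (ast.literal_eval fallback)
#   safe_json_loads('garbage', fallback_value={})   -> {}         (all strategies fail)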

def preprocess_json_escapes(json_string):
    """
    Preprocess JSON string to fix common escape issues from LLM responses.
    Uses a more sophisticated approach to handle nested quotes properly.
    """
    if not json_string:
        return json_string

    result = []
    i = 0
    in_string = False
    escape_next = False
    expecting_value = False  # Track if we're expecting a value (after :)

    while i < len(json_string):
        char = json_string[i]

        if escape_next:
            # This character is escaped, add it as-is
            result.append(char)
            escape_next = False
        elif char == '\\':
            # This is an escape character
            result.append(char)
            escape_next = True
        elif char == ':' and not in_string:
            # We found a colon, next string will be a value
            result.append(char)
            expecting_value = True
        elif char == ',' and not in_string:
            # We found a comma, reset expecting_value
            result.append(char)
            expecting_value = False
        elif char == '"':
            if not in_string:
                # Starting a string
                in_string = True
                result.append(char)
            else:
                # We're in a string, check if this quote should be escaped
                # Look ahead to see if this is the end of the string value
                j = i + 1
                while j < len(json_string) and json_string[j].isspace():
                    j += 1

                # For keys (not expecting_value), only end on colon
                # For values (expecting_value), end on comma, closing brace, or closing bracket
                if expecting_value:
                    end_chars = ',}]'
                else:
                    end_chars = ':'

                if j < len(json_string) and json_string[j] in end_chars:
                    # This is the end of the string
                    in_string = False
                    result.append(char)
                    if not expecting_value:
                        # We just finished a key, next will be expecting value
                        expecting_value = True
                else:
                    # This is an inner quote that should be escaped
                    result.append('\\"')
        else:
            result.append(char)

        i += 1

    return ''.join(result)
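
# Example of the escape pass (traced by hand through the loop above):
#   input:  {"k": "say "hi" now"}
#   output: {"k": "say \"hi\" now"}
# The inner quotes around hi get escaped because the next non-whitespace
# character after them is not one of , } ] — so they cannot close the value.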

def extract_json_object(text):
    """
    Extract a JSON object or array embedded in surrounding text.

    Note: the regex is greedy, so it captures from the first '{' to the
    last '}' (or first '[' to last ']') found in the text.
    """
    # Look for JSON object
    obj_match = re.search(r'\{.*\}', text, re.DOTALL)
    if obj_match:
        return obj_match.group(0)

    # Look for JSON array
    arr_match = re.search(r'\[.*\]', text, re.DOTALL)
    if arr_match:
        return arr_match.group(0)

    # Return original if no JSON structure found
    return text
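
# Example: extract_json_object('noise {"a": 1} more noise') -> '{"a": 1}'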

def run_comprehensive_test():
    """Run a comprehensive test with various malformed JSON examples."""
    print("🧪 Running comprehensive JSON preprocessing tests...\n")

    test_cases = [
        {
            "name": "Valid JSON",
            "input": '{"title": "Test", "summary": "Valid JSON"}',
            "should_succeed": True
        },
        {
            "name": "Unescaped quotes in title",
            "input": '{"title": "Meeting about "Project X"", "summary": "Discussion summary"}',
            "should_succeed": True
        },
        {
            "name": "Multiple unescaped quotes",
            "input": '{"title": "John said "Hello" and Mary replied "Hi there"", "summary": "Conversation log"}',
            "should_succeed": True
        },
        {
            "name": "Markdown code block",
            "input": '```json\n{"title": "Wrapped", "summary": "In code block"}\n```',
            "should_succeed": True
        },
        {
            "name": "Mixed quotes and apostrophes",
            "input": '{"title": "Alice\'s "big idea" presentation", "summary": "She said it\'s "revolutionary""}',
            "should_succeed": True
        },
        {
            "name": "JSON with newlines",
            "input": '{"title": "Multi-line", "summary": "Line 1\\nLine 2 with \\"quotes\\"\\nLine 3"}',
            "should_succeed": True
        },
        {
            "name": "LLM response with explanation",
            "input": '''Here's the JSON:
```json
{
    "title": "Q3 Planning",
    "summary": "We discussed the "new initiative" for next quarter."
}
```
Hope this helps!''',
            "should_succeed": True
        },
        {
            "name": "Speaker identification with quotes",
            "input": '{"SPEAKER_00": "John Smith", "SPEAKER_01": "Jane "The Expert" Doe", "SPEAKER_02": "Bob"}',
            "should_succeed": True
        },
        {
            "name": "Completely malformed",
            "input": '{"title": "Test", "summary": this is not valid json at all}',
            "should_succeed": False
        },
        {
            "name": "Empty string",
            "input": "",
            "should_succeed": False
        }
    ]

    passed = 0
    failed = 0

    for i, test_case in enumerate(test_cases, 1):
        print(f"Test {i}: {test_case['name']}")
        print(f"Input: {test_case['input'][:100]}{'...' if len(test_case['input']) > 100 else ''}")

        try:
            result = safe_json_loads(test_case['input'], {"error": "fallback"})

            if test_case['should_succeed']:
                if result != {"error": "fallback"} and isinstance(result, (dict, list)):
                    print("✅ PASSED - Successfully parsed JSON")
                    print(f"  Result: {result}")
                    passed += 1
                else:
                    print("❌ FAILED - Expected successful parsing but got fallback")
                    failed += 1
            else:
                if result == {"error": "fallback"}:
                    print("✅ PASSED - Correctly returned fallback for malformed JSON")
                    passed += 1
                else:
                    print("❌ FAILED - Expected fallback but got parsed result")
                    print(f"  Unexpected result: {result}")
                    failed += 1

        except Exception as e:
            print(f"❌ FAILED - Exception occurred: {e}")
            failed += 1

        print("-" * 50)

    print(f"\n📊 Test Results: {passed} passed, {failed} failed")
    return failed == 0


def test_preprocessing_function():
    """Test the preprocessing function directly."""
    print("\n🔧 Testing preprocessing function directly...\n")

    test_input = '{"title": "Meeting about "Project X"", "summary": "Discussion summary"}'
    print(f"Original: {test_input}")

    processed = preprocess_json_escapes(test_input)
    print(f"Processed: {processed}")

    try:
        result = json.loads(processed)
        print(f"✅ Successfully parsed: {result}")
    except json.JSONDecodeError as e:
        print(f"❌ Still failed: {e}")


def test_specific_scenarios():
    """Test specific real-world scenarios."""
    print("\n🎯 Testing specific LLM response scenarios...\n")

    # Test case from the original issue
    gemini_response = '''{"title": "Meeting about "Project Phoenix" and budget allocation", "summary": "The team discussed John's "breakthrough idea" and Mary said "this will change everything" during the Q3 planning session."}'''

    print("Testing Gemini-style response with unescaped quotes:")
    print(f"Input: {gemini_response}")

    # Test preprocessing directly
    processed = preprocess_json_escapes(gemini_response)
    print(f"Processed: {processed}")

    result = safe_json_loads(gemini_response)
    if isinstance(result, dict) and "title" in result and "summary" in result:
        print("✅ SUCCESS - Parsed Gemini response correctly!")
        print(f"Title: {result['title']}")
        print(f"Summary: {result['summary'][:100]}...")
    else:
        print("❌ FAILED - Could not parse Gemini response")
        print(f"Result: {result}")

    print("-" * 50)

    # Test speaker identification scenario
    speaker_response = '''{"SPEAKER_00": "John "The Manager" Smith", "SPEAKER_01": "Alice Johnson", "SPEAKER_02": "Bob "Tech Lead" Wilson"}'''

    print("Testing speaker identification with quotes in names:")
    print(f"Input: {speaker_response}")

    # Test preprocessing directly
    processed = preprocess_json_escapes(speaker_response)
    print(f"Processed: {processed}")

    result = safe_json_loads(speaker_response)
    if isinstance(result, dict) and len(result) >= 3:
        print("✅ SUCCESS - Parsed speaker identification correctly!")
        for speaker, name in result.items():
            print(f"  {speaker}: {name}")
    else:
        print("❌ FAILED - Could not parse speaker identification")
        print(f"Result: {result}")


if __name__ == "__main__":
    print("🚀 Starting Standalone JSON Preprocessing Tests\n")

    # Test preprocessing function directly
    test_preprocessing_function()

    # Run the comprehensive test
    success = run_comprehensive_test()

    # Test specific scenarios
    test_specific_scenarios()

    if success:
        print("\n🎉 All tests completed successfully! JSON preprocessing should handle LLM response issues gracefully.")
    else:
        print("\n⚠️ Some tests failed. The implementation may need refinement.")
251
tests/test_migration_compatibility.py
Normal file
@@ -0,0 +1,251 @@
"""
Test suite to ensure database migrations are compatible with both SQLite and PostgreSQL.

These tests scan the init_db.py file for patterns that would break on PostgreSQL,
such as SQLite-only boolean defaults (0/1 instead of FALSE/TRUE) and unquoted
reserved keywords.

Run with: python tests/test_migration_compatibility.py
"""
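
# Illustrative contrasts these tests enforce (sketches, not copied from init_db.py):
#   ALTER TABLE recording ADD COLUMN is_inbox BOOLEAN DEFAULT 0      -- SQLite-only
#   ALTER TABLE recording ADD COLUMN is_inbox BOOLEAN DEFAULT FALSE  -- portable
#   UPDATE recording SET ... WHERE is_inbox = 1                      -- fails on PostgreSQL
#   UPDATE recording SET ... WHERE is_inbox = TRUE                   -- portable
#   CREATE INDEX idx ON user (email)                                 -- 'user' is reserved
#   CREATE INDEX idx ON "user" (email)                               -- quoted, portable
#   ... DEFAULT "en"   -- double quotes denote an identifier in SQL
#   ... DEFAULT 'en'   -- single quotes denote a string literal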

import re
import unittest
import os


class TestMigrationCompatibility(unittest.TestCase):
    """Tests to ensure init_db.py uses cross-database compatible SQL."""

    @classmethod
    def setUpClass(cls):
        """Load init_db.py content once for all tests."""
        # Find the project root
        test_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.dirname(test_dir)
        init_db_path = os.path.join(project_root, 'src', 'init_db.py')

        with open(init_db_path, 'r') as f:
            cls.content = f.read()

    def test_no_raw_boolean_defaults_in_alter_table(self):
        """
        Ensure no raw ALTER TABLE statements use SQLite-only boolean defaults.

        The pattern 'BOOLEAN DEFAULT 0' or 'BOOLEAN DEFAULT 1' in raw SQL
        will fail on PostgreSQL, which requires 'DEFAULT FALSE' or 'DEFAULT TRUE'.

        Using add_column_if_not_exists() handles this conversion automatically.
        """
        # Pattern to find raw SQL with text() that has BOOLEAN DEFAULT 0/1
        # This matches: text('... BOOLEAN DEFAULT 0 ...') or text("...")
        pattern = r"conn\.execute\s*\(\s*text\s*\(['\"]([^'\"]*BOOLEAN\s+DEFAULT\s+[01][^'\"]*)['\"]"

        matches = re.findall(pattern, self.content, re.IGNORECASE)

        # Filter out false positives - we're looking for raw ALTER TABLE statements,
        # not UPDATE statements or other SQL that legitimately uses 0/1
        problematic = []
        for match in matches:
            match_upper = match.upper()
            # Only flag if it's an ALTER TABLE with BOOLEAN DEFAULT 0/1
            if 'ALTER TABLE' in match_upper and 'BOOLEAN' in match_upper:
                if 'DEFAULT 0' in match or 'DEFAULT 1' in match:
                    problematic.append(match)

        self.assertEqual(
            len(problematic), 0,
            f"Found SQLite-only boolean defaults in raw ALTER TABLE statements. "
            f"Use add_column_if_not_exists() instead:\n" +
            "\n".join(f"  - {m[:100]}..." if len(m) > 100 else f"  - {m}" for m in problematic)
        )

    def test_no_boolean_integer_comparisons_in_raw_sql(self):
        """
        Ensure raw SQL doesn't compare boolean columns to integers (0/1).

        PostgreSQL strictly separates boolean and integer types:
        - 'column = 1' fails with 'operator does not exist: boolean = integer'
        - 'column = TRUE' works on both SQLite (3.23+) and PostgreSQL

        Known boolean columns in migrations: protect_from_deletion, email_verified,
        auto_share_on_apply, share_with_group_lead, is_inbox, is_highlighted,
        deletion_exempt, is_admin, can_share_publicly.
        """
        boolean_columns = [
            'protect_from_deletion', 'email_verified', 'auto_share_on_apply',
            'share_with_group_lead', 'is_inbox', 'is_highlighted',
            'deletion_exempt', 'is_admin', 'can_share_publicly',
            'auto_speaker_labelling', 'auto_summarization'
        ]

        # Rather than parsing every text() call, scan for lines where a known
        # boolean column is compared to 0 or 1
        problematic = []
        for col in boolean_columns:
            # Match: column = 0 or column = 1 (not = TRUE/FALSE)
            pattern = rf"{col}\s*=\s*[01]\b"
            matches = re.finditer(pattern, self.content, re.IGNORECASE)
            for match in matches:
                # Get surrounding context to check if it's in a text() SQL call
                start = max(0, match.start() - 200)
                context = self.content[start:match.end() + 50]
                if 'text(' in context and 'sqlite_master' not in context:
                    problematic.append(f"{col}: ...{match.group()}...")

        self.assertEqual(
            len(problematic), 0,
            f"Found boolean columns compared to integers in raw SQL. "
            f"Use TRUE/FALSE instead of 1/0 for PostgreSQL compatibility:\n" +
            "\n".join(f"  - {p}" for p in problematic)
        )

    def test_reserved_keywords_quoted_in_index_creation(self):
        """
        Ensure reserved keywords like 'user' are properly quoted in index creation.

        Raw SQL like 'CREATE INDEX ... ON user (column)' will fail on some databases
        because 'user' is a reserved keyword. It should be quoted as "user" or use
        the create_index_if_not_exists() utility.
        """
        reserved_keywords = ['user', 'order', 'group', 'table', 'select', 'index']

        problematic = []

        for keyword in reserved_keywords:
            # Pattern to find unquoted reserved keyword after ON in index creation
            # Matches: CREATE INDEX ... ON user ( but not ON "user" or ON `user`
            pattern = rf"CREATE\s+(?:UNIQUE\s+)?INDEX[^;]*\s+ON\s+{keyword}\s*\("

            matches = re.findall(pattern, self.content, re.IGNORECASE)

            for match in matches:
                # Skip if the keyword is already quoted
                if f'"{keyword}"' in match.lower() or f'`{keyword}`' in match.lower():
                    continue
                problematic.append((keyword, match[:80]))

        self.assertEqual(
            len(problematic), 0,
            f"Found unquoted reserved keywords in index creation. "
            f"Use create_index_if_not_exists() or quote the table name:\n" +
            "\n".join(f"  - '{kw}' in: {sql}..." for kw, sql in problematic)
        )

    def test_add_column_uses_utility(self):
        """
        Ensure most ADD COLUMN operations use add_column_if_not_exists().

        Direct ALTER TABLE ADD COLUMN statements should use the utility function
        to ensure cross-database compatibility with boolean defaults and quoting.
        """
        # Count direct ALTER TABLE ADD COLUMN in text() calls
        direct_pattern = r"conn\.execute\s*\(\s*text\s*\(['\"][^'\"]*ALTER\s+TABLE[^'\"]*ADD\s+COLUMN"
        direct_matches = re.findall(direct_pattern, self.content, re.IGNORECASE)

        # Count uses of add_column_if_not_exists
        utility_pattern = r"add_column_if_not_exists\s*\("
        utility_matches = re.findall(utility_pattern, self.content)

        # We expect most ADD COLUMN operations to use the utility.
        # Allow some direct usage for special cases (e.g., table recreation),
        # but utility usage should significantly outnumber direct usage.
        self.assertGreater(
            len(utility_matches), len(direct_matches),
            f"Found {len(direct_matches)} direct ALTER TABLE ADD COLUMN statements "
            f"vs {len(utility_matches)} add_column_if_not_exists() calls. "
            f"Consider using the utility function for cross-database compatibility."
        )

    def test_incompatible_types_handled_by_utility(self):
        """
        Ensure columns with PostgreSQL-incompatible types (DATETIME, BLOB) are
        added through add_column_if_not_exists() which auto-converts them,
        and NOT via raw ALTER TABLE statements that would bypass conversion.

        PostgreSQL type differences:
        - DATETIME -> TIMESTAMP
        - BLOB -> BYTEA
        """
        incompatible_types = ['DATETIME', 'BLOB']

        # Check for raw ALTER TABLE statements using incompatible types
        for sql_type in incompatible_types:
            pattern = rf"conn\.execute\s*\(\s*text\s*\(['\"][^'\"]*ALTER\s+TABLE[^'\"]*\b{sql_type}\b[^'\"]*['\"]"
            matches = re.findall(pattern, self.content, re.IGNORECASE)

            self.assertEqual(
                len(matches), 0,
                f"Found raw ALTER TABLE statements using '{sql_type}' which is incompatible with PostgreSQL. "
                f"Use add_column_if_not_exists() which auto-converts types:\n" +
                "\n".join(f"  - {m[:100]}..." if len(m) > 100 else f"  - {m}" for m in matches)
            )

        # Columns of these types may still be added via add_column_if_not_exists();
        # that is fine because the utility converts the type per engine.

    def test_no_double_quoted_string_defaults(self):
        """
        Ensure no SQL DEFAULT values use double-quoted strings.

        In SQL, double quotes denote identifiers (column/table names), not string
        literals. SQLite tolerates this, but PostgreSQL will interpret DEFAULT "en"
        as a reference to a column named "en" and fail with 'column "en" does not exist'.

        String defaults must use single quotes: DEFAULT 'en'
        """
        # Match DEFAULT followed by a double-quoted string value
        pattern = r'DEFAULT\s+"[^"]*"'

        lines = self.content.splitlines()
        problematic = []
        for i, line in enumerate(lines, 1):
            if re.search(pattern, line, re.IGNORECASE):
                problematic.append(f"  Line {i}: {line.strip()}")

        self.assertEqual(
            len(problematic), 0,
            f"Found double-quoted string defaults in init_db.py. "
            f"PostgreSQL interprets double quotes as column identifiers, not string literals. "
            f"Use single quotes instead (e.g., DEFAULT 'en' not DEFAULT \"en\"):\n" +
            "\n".join(problematic)
        )

    def test_create_index_uses_utility_for_user_table(self):
        """
        Ensure index creation on 'user' table uses create_index_if_not_exists().

        The 'user' table name is a reserved keyword that requires special quoting.
        Using create_index_if_not_exists() handles this automatically.
        """
        # Count raw index creation on user table in text() calls
        raw_pattern = r"conn\.execute\s*\(\s*text\s*\(['\"][^'\"]*CREATE\s+(?:UNIQUE\s+)?INDEX[^'\"]*ON\s+[\"'`]?user"
        raw_matches = re.findall(raw_pattern, self.content, re.IGNORECASE)

        # Count uses of create_index_if_not_exists for user table
        utility_pattern = r"create_index_if_not_exists\s*\([^)]*['\"]user['\"]"
        utility_matches = re.findall(utility_pattern, self.content, re.IGNORECASE)

        # All index creation on the user table should use the utility
        # (excluding table recreation scenarios which have their own quoting)
        if len(raw_matches) > 0:
            # Check if these are in table recreation blocks (acceptable)
            table_recreation_pattern = r"CREATE\s+TABLE\s+user_new"
            has_table_recreation = re.search(table_recreation_pattern, self.content, re.IGNORECASE)

            if not has_table_recreation or len(raw_matches) > 1:
                self.fail(
                    f"Found {len(raw_matches)} raw CREATE INDEX statements on 'user' table. "
                    f"Use create_index_if_not_exists() for proper quoting of reserved keywords."
                )


if __name__ == '__main__':
    unittest.main()
165
tests/test_postgres_migrations.py
Normal file
@@ -0,0 +1,165 @@
"""
Integration test for database migrations against a real database engine.

Runs initialize_database() and verifies that all tables and critical columns
are created successfully. Works with both SQLite (default, for local runs)
and PostgreSQL (when TEST_DATABASE_URI env var is set).

IMPORTANT: This test uses TEST_DATABASE_URI (not SQLALCHEMY_DATABASE_URI) to
avoid accidentally connecting to and destroying a real application database.

Usage:
    # Local (SQLite in-memory, safe):
    python tests/test_postgres_migrations.py

    # Against PostgreSQL (CI or explicit testing):
    TEST_DATABASE_URI=postgresql://user:pass@localhost:5432/testdb \
        python tests/test_postgres_migrations.py
"""

import os
import sys
import unittest

# Ensure project root is on the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from flask import Flask
from src.database import db
# Importing models registers them with SQLAlchemy so create_all() builds all tables
import src.models  # noqa: F401
from src.init_db import initialize_database


def create_test_app():
    """Create a minimal Flask app for testing database operations.

    Uses TEST_DATABASE_URI env var (NOT SQLALCHEMY_DATABASE_URI) to prevent
    accidental connection to production/dev databases.
    """
    app = Flask(__name__)
    app.config['SQLALCHEMY_DATABASE_URI'] = os.environ.get(
        'TEST_DATABASE_URI', 'sqlite://'  # in-memory SQLite by default
    )
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    app.config['TESTING'] = True
    db.init_app(app)
    return app


class TestDatabaseMigrations(unittest.TestCase):
    """Test that initialize_database() runs cleanly against the configured DB engine."""

    @classmethod
    def setUpClass(cls):
        cls.app = create_test_app()
        with cls.app.app_context():
            initialize_database(cls.app)

    @classmethod
    def tearDownClass(cls):
        with cls.app.app_context():
            # Use raw DROP to avoid circular FK dependency errors in SQLAlchemy's
            # drop_all() (user <-> naming_template have mutual foreign keys)
            from sqlalchemy import inspect, text
            tables = inspect(db.engine).get_table_names()
            with db.engine.connect() as conn:
                if db.engine.name == 'postgresql':
                    for table in tables:
                        conn.execute(text(f'DROP TABLE IF EXISTS "{table}" CASCADE'))
                else:
                    conn.execute(text('PRAGMA foreign_keys = OFF'))
                    for table in tables:
                        conn.execute(text(f'DROP TABLE IF EXISTS "{table}"'))
                    conn.execute(text('PRAGMA foreign_keys = ON'))
                conn.commit()

    def _get_table_names(self):
        from sqlalchemy import inspect
        inspector = inspect(db.engine)
        return inspector.get_table_names()

    def _get_column_names(self, table):
        from sqlalchemy import inspect
        inspector = inspect(db.engine)
        return [col['name'] for col in inspector.get_columns(table)]

    def test_core_tables_exist(self):
        """Verify that all core tables were created."""
        with self.app.app_context():
            tables = self._get_table_names()
            expected_tables = [
                'user', 'recording', 'transcript_chunk', 'tag',
                'folder', 'share', 'internal_share', 'system_setting',
                'speaker', 'processing_job', 'group', 'group_membership',
            ]
            for table in expected_tables:
                self.assertIn(table, tables, f"Missing table: {table}")

    def test_user_migration_columns(self):
        """Verify columns added by migrations exist on the user table."""
        with self.app.app_context():
            columns = self._get_column_names('user')
            expected = [
                'id', 'username', 'email', 'password',
                'transcription_language', 'output_language', 'ui_language',
                'summary_prompt', 'extract_events', 'name', 'job_title',
                'company', 'diarize', 'sso_provider', 'sso_subject',
                'can_share_publicly', 'monthly_token_budget',
                'monthly_transcription_budget', 'email_verified',
                'auto_speaker_labelling', 'auto_summarization',
            ]
            for col in expected:
                self.assertIn(col, columns, f"Missing user column: {col}")

    def test_recording_migration_columns(self):
        """Verify columns added by migrations exist on the recording table."""
        with self.app.app_context():
            columns = self._get_column_names('recording')
            expected = [
                'id', 'is_inbox', 'is_highlighted', 'mime_type',
                'completed_at', 'processing_time_seconds', 'error_message',
                'folder_id', 'audio_deleted_at', 'deletion_exempt',
                'speaker_embeddings',
            ]
            for col in expected:
                self.assertIn(col, columns, f"Missing recording column: {col}")

    def test_tag_migration_columns(self):
        """Verify columns added by migrations exist on the tag table."""
        with self.app.app_context():
            columns = self._get_column_names('tag')
            expected = [
                'id', 'protect_from_deletion', 'group_id',
                'retention_days', 'auto_share_on_apply',
                'naming_template_id', 'export_template_id',
            ]
            for col in expected:
                self.assertIn(col, columns, f"Missing tag column: {col}")

    def test_system_settings_initialized(self):
        """Verify that default system settings were created."""
        with self.app.app_context():
            from src.models import SystemSetting
            expected_keys = [
                'transcript_length_limit', 'max_file_size_mb',
                'asr_timeout_seconds', 'disable_auto_summarization',
                'enable_folders',
            ]
            for key in expected_keys:
                setting = SystemSetting.query.filter_by(key=key).first()
                self.assertIsNotNone(setting, f"Missing system setting: {key}")

    def test_engine_type_matches_expectation(self):
        """Sanity check: confirm we're testing against the expected engine."""
        with self.app.app_context():
            uri = self.app.config['SQLALCHEMY_DATABASE_URI']
            engine_name = db.engine.name
            if uri.startswith('postgresql'):
                self.assertEqual(engine_name, 'postgresql')
            else:
                self.assertEqual(engine_name, 'sqlite')


if __name__ == '__main__':
    unittest.main()
435
tests/test_video_passthrough.py
Normal file
@@ -0,0 +1,435 @@
"""
Test suite for the VIDEO_PASSTHROUGH_ASR feature.

Tests configuration, code path correctness, and interaction with VIDEO_RETENTION
across all entry points (processing pipeline, upload handler, file monitor, incognito).
Uses static analysis — no running server or real video files required.

Run with: python tests/test_video_passthrough.py
"""
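
# The feature under test is driven by a single env flag; the canonical
# definition asserted below (in test_canonical_definition_in_app_config) is:
#   VIDEO_PASSTHROUGH_ASR = os.environ.get('VIDEO_PASSTHROUGH_ASR', 'false').lower() == 'true'
# When true, the original video file is handed to the ASR service as-is —
# skipping audio extraction, conversion, and chunking, as exercised by the
# tests below.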

import os
import re
import sys
import unittest
from pathlib import Path

TEST_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(TEST_DIR)
sys.path.insert(0, PROJECT_ROOT)


def read_file(rel_path):
    with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
        return f.read()


# Cache file contents once — they don't change during the run
PROCESSING = read_file('src/tasks/processing.py')
RECORDINGS = read_file('src/api/recordings.py')
FILE_MONITOR = read_file('src/file_monitor.py')
APP_CONFIG = read_file('src/config/app_config.py')
ENV_EXAMPLE = read_file('config/env.transcription.example')


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def get_function_body(source, func_name):
    """Extract the body of a top-level function from source code."""
    pattern = rf'^def {func_name}\('
    lines = source.split('\n')
    start = None
    for i, line in enumerate(lines):
        if re.match(pattern, line):
            start = i
            break
    if start is None:
        return ''
    # Collect until next top-level def or class or EOF
    body_lines = [lines[start]]
    for line in lines[start + 1:]:
        if line and not line[0].isspace() and (line.startswith('def ') or line.startswith('class ')):
            break
        body_lines.append(line)
    return '\n'.join(body_lines)


def split_at_incognito(source):
    """Split processing.py into main and incognito sections."""
    marker = 'def transcribe_incognito('
    idx = source.find(marker)
    if idx == -1:
        return source, ''
    return source[:idx], source[idx:]


PROCESSING_MAIN, PROCESSING_INCOGNITO = split_at_incognito(PROCESSING)


# ===========================================================================
# 1. Configuration
# ===========================================================================

class TestPassthroughConfig(unittest.TestCase):
    """VIDEO_PASSTHROUGH_ASR env var is defined and defaults to false."""

    FILES_THAT_NEED_IT = [
        ('src/config/app_config.py', APP_CONFIG),
        ('src/tasks/processing.py', PROCESSING),
        ('src/api/recordings.py', RECORDINGS),
        ('src/file_monitor.py', FILE_MONITOR),
    ]

    def test_defined_in_all_files(self):
        for rel_path, content in self.FILES_THAT_NEED_IT:
            with self.subTest(file=rel_path):
                self.assertIn('VIDEO_PASSTHROUGH_ASR', content,
                              f"VIDEO_PASSTHROUGH_ASR missing from {rel_path}")

    def test_default_is_false_everywhere(self):
        for rel_path, content in self.FILES_THAT_NEED_IT:
            match = re.search(
                r"VIDEO_PASSTHROUGH_ASR\s*=\s*os\.environ\.get\('VIDEO_PASSTHROUGH_ASR',\s*'(\w+)'\)",
                content
            )
            if match:
                with self.subTest(file=rel_path):
                    self.assertEqual(match.group(1), 'false',
                                     f"Default should be 'false' in {rel_path}")

    def test_canonical_definition_in_app_config(self):
        self.assertIn(
            "VIDEO_PASSTHROUGH_ASR = os.environ.get('VIDEO_PASSTHROUGH_ASR', 'false').lower() == 'true'",
            APP_CONFIG
        )

    def test_documented_in_env_example(self):
        self.assertIn('VIDEO_PASSTHROUGH_ASR', ENV_EXAMPLE)

    def test_processing_imports_from_config(self):
        """processing.py should import the flag from app_config, not re-read os.environ."""
        self.assertIn('VIDEO_PASSTHROUGH_ASR', PROCESSING)
        # Verify the name appears in an import line from app_config
        import_lines = [l for l in PROCESSING.split('\n')
                        if 'from src.config.app_config import' in l]
        found = any('VIDEO_PASSTHROUGH_ASR' in l for l in import_lines)
        self.assertTrue(found, "processing.py should import VIDEO_PASSTHROUGH_ASR from app_config")
|
||||
|
||||
# ===========================================================================
|
||||
# 2. Processing pipeline — main transcription path
|
||||
# ===========================================================================

class TestProcessingMainPath(unittest.TestCase):
    """Test transcribe_with_connector() video passthrough code paths."""

    def test_passthrough_branch_exists_before_retention(self):
        """VIDEO_PASSTHROUGH_ASR is checked before VIDEO_RETENTION in the is_video block."""
        # Inside the `if is_video:` block, passthrough should be the first check
        video_block_start = PROCESSING_MAIN.find('if is_video:')
        self.assertNotEqual(video_block_start, -1)

        after_video = PROCESSING_MAIN[video_block_start:]
        passthrough_pos = after_video.find('if VIDEO_PASSTHROUGH_ASR:')
        retention_pos = after_video.find('elif VIDEO_RETENTION:')

        self.assertNotEqual(passthrough_pos, -1, "Missing VIDEO_PASSTHROUGH_ASR check in is_video block")
        self.assertNotEqual(retention_pos, -1, "Missing elif VIDEO_RETENTION check")
        self.assertLess(passthrough_pos, retention_pos,
                        "VIDEO_PASSTHROUGH_ASR should be checked before VIDEO_RETENTION")

    def test_passthrough_does_not_call_extract_audio(self):
        """The passthrough branch must not call extract_audio_from_video."""
        video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
        # Find the passthrough branch (from `if VIDEO_PASSTHROUGH_ASR:` to `elif VIDEO_RETENTION:`)
        pt_start = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
        pt_end = video_block.find('elif VIDEO_RETENTION:')
        passthrough_block = video_block[pt_start:pt_end]
        self.assertNotIn('extract_audio_from_video', passthrough_block,
                         "Passthrough branch should NOT extract audio")

    def test_passthrough_keeps_original_filepath(self):
        """Passthrough sets actual_filepath = filepath (the original video)."""
        video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
        pt_start = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
        pt_end = video_block.find('elif VIDEO_RETENTION:')
        passthrough_block = video_block[pt_start:pt_end]
        self.assertIn('actual_filepath = filepath', passthrough_block)

    def test_passthrough_with_retention_sets_recording_path(self):
        """When both passthrough and retention are on, recording.audio_path is set."""
        video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
        pt_start = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
        pt_end = video_block.find('elif VIDEO_RETENTION:')
        passthrough_block = video_block[pt_start:pt_end]
        self.assertIn('if VIDEO_RETENTION:', passthrough_block,
                      "Passthrough branch should conditionally handle retention")
        self.assertIn('recording.audio_path = filepath', passthrough_block)
        self.assertIn("mimetypes.guess_type(filepath)", passthrough_block)

    def test_video_passthrough_active_flag_set(self):
        """video_passthrough_active flag is computed from is_video and VIDEO_PASSTHROUGH_ASR."""
        self.assertIn('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR',
                      PROCESSING_MAIN)

    def test_conversion_skipped_when_passthrough(self):
        """convert_if_needed is inside an else block gated by video_passthrough_active."""
        self.assertIn('if video_passthrough_active:', PROCESSING_MAIN)
        # The conversion call should be in the else branch
        flag_pos = PROCESSING_MAIN.find('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR')
        after_flag = PROCESSING_MAIN[flag_pos:]
        passthrough_if = after_flag.find('if video_passthrough_active:')
        else_pos = after_flag.find('else:', passthrough_if)
        convert_pos = after_flag.find('convert_if_needed(', else_pos)
        self.assertGreater(convert_pos, else_pos,
                           "convert_if_needed should be in else branch after passthrough check")

    def test_chunking_skipped_when_passthrough(self):
        """Chunking evaluates to False when video_passthrough_active."""
        # Find the chunking decision area after the flag
        flag_pos = PROCESSING_MAIN.find('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR')
        after_flag = PROCESSING_MAIN[flag_pos:]
        self.assertIn('if video_passthrough_active:\n should_chunk = False', after_flag)

    def test_conversion_still_runs_for_non_passthrough(self):
        """convert_if_needed still runs when passthrough is off or file is audio."""
        # The else branch of the passthrough check should contain convert_if_needed
        self.assertIn('conversion_result = convert_if_needed(', PROCESSING_MAIN)

    def test_chunking_still_evaluated_for_non_passthrough(self):
        """Chunking is still evaluated normally when passthrough is not active."""
        self.assertIn('chunking_service.needs_chunking(actual_filepath, False, connector_specs)',
                      PROCESSING_MAIN)


# ===========================================================================
# 3. Processing pipeline — VIDEO_RETENTION paths still intact
# ===========================================================================

class TestRetentionNotBroken(unittest.TestCase):
    """Existing VIDEO_RETENTION behavior must be preserved."""

    def test_retention_branch_still_extracts_audio(self):
        """elif VIDEO_RETENTION branch still calls extract_audio_from_video."""
        video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
        ret_start = video_block.find('elif VIDEO_RETENTION:')
        # Find next else: at the same indent level
        after_ret = video_block[ret_start:]
        else_pos = after_ret.find('\n else:')
        retention_block = after_ret[:else_pos] if else_pos != -1 else after_ret[:500]
        self.assertIn('extract_audio_from_video(filepath, cleanup_original=False)',
                      retention_block)

    def test_default_branch_still_extracts_and_deletes(self):
        """The final else branch extracts audio with default cleanup (deletes video)."""
        video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
        # The last else in the is_video block
        self.assertIn('extract_audio_from_video(filepath)', video_block)

    def test_temp_audio_cleanup_still_present(self):
        """Temp audio from retention is still cleaned up after transcription."""
        self.assertIn('is_video and VIDEO_RETENTION and audio_filepath', PROCESSING_MAIN)
        self.assertIn('Cleaned up temp audio from video retention', PROCESSING_MAIN)


# ===========================================================================
# 4. Incognito path
# ===========================================================================
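
# Incognito branch shape asserted below (sketch; the log strings are the ones
# the tests pin down, everything else — including the logger name — is assumed):
#
#   if is_video:
#       if VIDEO_PASSTHROUGH_ASR:
#           logging.info('[Incognito] Video passthrough: sending original video to ASR')
#       else:
#           audio_filepath = extract_audio_from_video(filepath, cleanup_original=False)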

class TestIncognitoPassthrough(unittest.TestCase):
    """Test passthrough in the incognito transcription path."""

    def test_passthrough_flag_set_in_incognito(self):
        """video_passthrough_active is computed in the incognito path."""
        self.assertIn('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR',
                      PROCESSING_INCOGNITO)

    def test_passthrough_skips_extraction_in_incognito(self):
        """When passthrough is on, incognito skips extract_audio_from_video."""
        # The passthrough branch logs and does NOT extract
        self.assertIn('[Incognito] Video passthrough: sending original video to ASR',
                      PROCESSING_INCOGNITO)

    def test_passthrough_skips_conversion_in_incognito(self):
        """When passthrough is on, incognito skips convert_if_needed."""
        self.assertIn('[Incognito] Video passthrough: skipping codec conversion',
                      PROCESSING_INCOGNITO)

    def test_passthrough_skips_chunking_in_incognito(self):
        """When passthrough is on, incognito chunking is False."""
        body = PROCESSING_INCOGNITO
        self.assertIn('if video_passthrough_active:\n should_chunk = False', body)

    def test_incognito_does_not_reference_video_retention(self):
        """Incognito path should NOT reference VIDEO_RETENTION (no retention in incognito)."""
        self.assertNotIn('VIDEO_RETENTION', PROCESSING_INCOGNITO)

    def test_incognito_still_extracts_without_passthrough(self):
        """Without passthrough, incognito still extracts audio from video."""
        self.assertIn('extract_audio_from_video(filepath, cleanup_original=False)',
                      PROCESSING_INCOGNITO)

    def test_incognito_still_converts_without_passthrough(self):
        """Without passthrough, incognito still runs convert_if_needed."""
        self.assertIn('convert_if_needed(', PROCESSING_INCOGNITO)


# ===========================================================================
# 5. Upload handler (recordings.py)
# ===========================================================================
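
# Guard shape asserted below for the upload handler (sketch; note that the
# tests deliberately search without the opening parenthesis):
#
#   if (VIDEO_RETENTION or VIDEO_PASSTHROUGH_ASR) and has_video:
#       ...skip convert_if_needed, log which flag caused the skip...
#   else:
#       convert_if_needed(...)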

class TestUploadHandlerPassthrough(unittest.TestCase):
    """Test recordings.py upload handler respects VIDEO_PASSTHROUGH_ASR."""

    def test_skip_conversion_for_passthrough_video(self):
        """Upload handler skips conversion when passthrough or retention + video."""
        # Leading '(' intentionally omitted so the match doesn't pin the full expression
        self.assertIn('VIDEO_RETENTION or VIDEO_PASSTHROUGH_ASR) and has_video', RECORDINGS)

    def test_extension_fallback_checks_passthrough(self):
        """Extension-based video detection also fires for VIDEO_PASSTHROUGH_ASR."""
        self.assertIn('VIDEO_RETENTION or VIDEO_PASSTHROUGH_ASR', RECORDINGS)

    def test_convert_if_needed_still_in_else(self):
        """convert_if_needed still runs for audio files or when both flags are off."""
        self.assertIn('convert_if_needed(', RECORDINGS)

    def test_passthrough_log_message(self):
        """Upload handler logs which mode caused the skip."""
        self.assertIn("'VIDEO_PASSTHROUGH_ASR'", RECORDINGS)


# ===========================================================================
# 6. File monitor
# ===========================================================================

class TestFileMonitorPassthrough(unittest.TestCase):
    """Test file_monitor.py respects VIDEO_PASSTHROUGH_ASR."""

    def test_passthrough_defined(self):
        self.assertIn('VIDEO_PASSTHROUGH_ASR', FILE_MONITOR)

    def test_skip_conversion_for_passthrough_or_retention(self):
        """File monitor skips conversion when passthrough or retention + video."""
        self.assertIn('VIDEO_PASSTHROUGH_ASR or VIDEO_RETENTION) and has_video', FILE_MONITOR)

    def test_convert_if_needed_in_else_branch(self):
        """convert_if_needed is in the else branch, not inside the skip block."""
        lines = FILE_MONITOR.split('\n')
        in_skip_block = False
        found_else = False
        for i, line in enumerate(lines):
            if 'VIDEO_PASSTHROUGH_ASR or VIDEO_RETENTION) and has_video' in line:
                in_skip_block = True
            elif in_skip_block and line.strip().startswith('else:'):
                in_skip_block = False
                found_else = True
            elif in_skip_block and 'convert_if_needed' in line:
                self.fail(f"convert_if_needed inside skip block at line {i + 1}")
        self.assertTrue(found_else, "Should have else branch after passthrough/retention skip")

    def test_log_distinguishes_passthrough_from_retention(self):
        """Log message indicates whether passthrough or retention caused the skip."""
        self.assertIn("'passthrough'", FILE_MONITOR)
        self.assertIn("'retention'", FILE_MONITOR)


# ===========================================================================
# 7. Audio files unaffected by passthrough
# ===========================================================================

class TestAudioUnaffected(unittest.TestCase):
    """VIDEO_PASSTHROUGH_ASR must only affect video files, never audio."""

    def test_passthrough_flag_gated_on_is_video(self):
        """video_passthrough_active is always `is_video and VIDEO_PASSTHROUGH_ASR`."""
        # Main path
        self.assertIn('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR',
                      PROCESSING_MAIN)
        # Incognito path
        self.assertIn('video_passthrough_active = is_video and VIDEO_PASSTHROUGH_ASR',
                      PROCESSING_INCOGNITO)

    def test_upload_handler_gated_on_has_video(self):
        """Upload handler skip is gated on `has_video`."""
        self.assertIn('and has_video', RECORDINGS)

    def test_file_monitor_gated_on_has_video(self):
        """File monitor skip is gated on `has_video`."""
        self.assertIn('and has_video', FILE_MONITOR)


# ===========================================================================
# 8. Documentation
# ===========================================================================

class TestDocumentation(unittest.TestCase):
    """VIDEO_PASSTHROUGH_ASR is documented in all relevant places."""

    DOC_FILES = [
        'config/env.transcription.example',
        'docs/admin-guide/system-settings.md',
        'docs/features.md',
        'docs/getting-started/installation.md',
    ]

    def test_documented_in_all_relevant_files(self):
        for rel_path in self.DOC_FILES:
            content = read_file(rel_path)
            with self.subTest(file=rel_path):
                self.assertIn('VIDEO_PASSTHROUGH_ASR', content,
                              f"VIDEO_PASSTHROUGH_ASR missing from {rel_path}")

    def test_env_example_commented_out_by_default(self):
        """The env example has the option commented out (opt-in)."""
        self.assertIn('# VIDEO_PASSTHROUGH_ASR=false', ENV_EXAMPLE)

    def test_docs_warn_about_asr_compatibility(self):
        """Docs warn that standard APIs will reject video input."""
        system_settings = read_file('docs/admin-guide/system-settings.md')
        installation = read_file('docs/getting-started/installation.md')
        self.assertIn('reject', system_settings.lower())
        self.assertIn('reject', installation.lower())


# ===========================================================================
# 9. Interaction matrix — structural verification
# ===========================================================================

class TestInteractionMatrix(unittest.TestCase):
    """
    Verify the 3-way branch structure in processing.py:
        if VIDEO_PASSTHROUGH_ASR → passthrough
        elif VIDEO_RETENTION     → retention
        else                     → default extraction
    """

    def test_three_way_branch_in_main_path(self):
        """Main path has if/elif/else for passthrough/retention/default."""
        video_block = PROCESSING_MAIN[PROCESSING_MAIN.find('if is_video:'):]
        # All three branches present, in order
        pt_pos = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
        ret_pos = video_block.find('elif VIDEO_RETENTION:')
        else_pos = video_block.find('\n else:', ret_pos)
        self.assertNotEqual(pt_pos, -1)
        self.assertNotEqual(ret_pos, -1)
        self.assertNotEqual(else_pos, -1)
        self.assertLess(pt_pos, ret_pos)
        self.assertLess(ret_pos, else_pos)

    def test_incognito_two_way_branch(self):
        """Incognito has if/else for passthrough/extract (no retention)."""
        video_block = PROCESSING_INCOGNITO[PROCESSING_INCOGNITO.find('if is_video:'):]
        pt_pos = video_block.find('if VIDEO_PASSTHROUGH_ASR:')
        else_pos = video_block.find('\n else:', pt_pos)
        self.assertNotEqual(pt_pos, -1)
        self.assertNotEqual(else_pos, -1)
        # No VIDEO_RETENTION in incognito
        incognito_video_block = video_block[:500]
        self.assertNotIn('VIDEO_RETENTION', incognito_video_block)


if __name__ == '__main__':
    unittest.main(verbosity=2)
370
tests/test_video_retention.py
Normal file
@@ -0,0 +1,370 @@
"""
Test suite for the VIDEO_RETENTION feature.

Tests code paths, configuration, and template correctness for video retention.
Does NOT require a running server or real video files; uses static analysis
and mocking where possible.

Run with: python tests/test_video_retention.py
"""

import os
import re
import sys
import json
import unittest
from unittest.mock import patch, MagicMock
from pathlib import Path

# Find the project root
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(TEST_DIR)
sys.path.insert(0, PROJECT_ROOT)


class TestVideoRetentionConfig(unittest.TestCase):
    """Test that the VIDEO_RETENTION env var is read correctly everywhere."""

    ALL_FILES = [
        'src/app.py',
        'src/tasks/processing.py',
        'src/api/system.py',
        'src/api/recordings.py',
        'src/file_monitor.py',
    ]

    def _read_file(self, rel_path):
        with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
            return f.read()

    def test_env_var_read_in_all_entry_points(self):
        """VIDEO_RETENTION env var is read in all files that need it."""
        for rel_path in self.ALL_FILES:
            content = self._read_file(rel_path)
            self.assertIn("VIDEO_RETENTION", content, f"VIDEO_RETENTION missing from {rel_path}")

    def test_exposed_in_api_config(self):
        """VIDEO_RETENTION is exposed in the /api/config response."""
        content = self._read_file('src/api/system.py')
        self.assertIn("'video_retention': VIDEO_RETENTION", content)

    def test_default_is_false(self):
        """All VIDEO_RETENTION reads default to 'false'."""
        for rel_path in self.ALL_FILES:
            content = self._read_file(rel_path)
            match = re.search(r"VIDEO_RETENTION\s*=\s*os\.environ\.get\('VIDEO_RETENTION',\s*'(\w+)'\)", content)
            if match:
                self.assertEqual(match.group(1), 'false', f"Default should be 'false' in {rel_path}")


class TestProcessingPipelineVideoRetention(unittest.TestCase):
    """Test processing.py video retention code paths via static analysis."""

    @classmethod
    def setUpClass(cls):
        with open(os.path.join(PROJECT_ROOT, 'src/tasks/processing.py'), 'r') as f:
            cls.content = f.read()

    def test_video_retention_true_keeps_original(self):
        """When VIDEO_RETENTION=True, recording.audio_path is set to the original filepath."""
        # The VIDEO_RETENTION=True branch should set recording.audio_path = filepath
        self.assertIn('recording.audio_path = filepath', self.content)

    def test_video_retention_true_extracts_without_cleanup(self):
        """When VIDEO_RETENTION=True, extract_audio_from_video is called with cleanup_original=False."""
        self.assertIn('extract_audio_from_video(filepath, cleanup_original=False)', self.content)

    def test_video_retention_false_extracts_with_cleanup(self):
        """When VIDEO_RETENTION=False, extract_audio_from_video is called with default cleanup."""
        self.assertIn('extract_audio_from_video(filepath)', self.content)

    def test_temp_audio_cleanup_after_transcription(self):
        """Temp audio from video retention is cleaned up after transcription."""
        self.assertIn('is_video and VIDEO_RETENTION and audio_filepath', self.content)
        self.assertIn('Cleaned up temp audio from video retention', self.content)

    def test_audio_filepath_initialized_to_none(self):
        """audio_filepath is initialized to None before the is_video check."""
        self.assertIn('audio_filepath = None', self.content)

    def test_video_mime_type_set_for_retention(self):
        """When retaining video, mime_type reflects the actual video type."""
        self.assertIn("mimetypes.guess_type(filepath)[0] or 'video/mp4'", self.content)

    def test_duration_uses_recording_audio_path(self):
        """Duration lookup uses recording.audio_path (always valid), not filepath."""
        self.assertIn('chunking_service.get_audio_duration(recording.audio_path)', self.content)
        # Should NOT use the bare filepath for duration (pre-existing bug was fixed)
        self.assertNotIn('chunking_service.get_audio_duration(filepath)', self.content)


class TestUploadHandlerVideoRetention(unittest.TestCase):
    """Test recordings.py upload handler video retention code paths."""

    @classmethod
    def setUpClass(cls):
        with open(os.path.join(PROJECT_ROOT, 'src/api/recordings.py'), 'r') as f:
            cls.content = f.read()

    def test_upload_handler_skips_conversion_for_video_retention(self):
        """Upload handler skips convert_if_needed for videos when retention is on."""
        self.assertIn('VIDEO_RETENTION and has_video', self.content)
        self.assertIn('skipping conversion', self.content)

    def test_upload_handler_has_video_from_codec_info(self):
        """Upload handler reads has_video from the codec_info probe."""
        self.assertIn("has_video = codec_info.get('has_video', False)", self.content)

    def test_convert_if_needed_still_in_else_branch(self):
        """convert_if_needed still runs for non-video files or when retention is off."""
        self.assertIn('convert_if_needed(', self.content)

    def test_processing_pipeline_still_converts_audio(self):
        """Processing pipeline runs convert_if_needed on extracted audio (the safety net)."""
        with open(os.path.join(PROJECT_ROOT, 'src/tasks/processing.py'), 'r') as f:
            proc_content = f.read()
        # After the video extraction block, convert_if_needed runs on actual_filepath
        self.assertIn('conversion_result = convert_if_needed(\n'
                      ' filepath=actual_filepath,', proc_content)


class TestFileMonitorVideoRetention(unittest.TestCase):
    """Test file_monitor.py video retention code paths."""

    @classmethod
    def setUpClass(cls):
        with open(os.path.join(PROJECT_ROOT, 'src/file_monitor.py'), 'r') as f:
            cls.content = f.read()

    def test_video_retention_skips_conversion(self):
        """When VIDEO_RETENTION=True and has_video=True, convert_if_needed is skipped."""
        # Should have the guard: if VIDEO_RETENTION and has_video: ... skip conversion
        self.assertIn('VIDEO_RETENTION and has_video', self.content)
        self.assertIn('skipping conversion', self.content)

    def test_no_double_extraction(self):
        """File monitor does NOT call convert_if_needed for videos when retention is on."""
        # The convert_if_needed call should be in the else branch
        lines = self.content.split('\n')
        in_retention_skip_block = False
        found_convert_in_else = False

        for i, line in enumerate(lines):
            if 'VIDEO_RETENTION and has_video' in line and 'if' in line:
                in_retention_skip_block = True
            elif in_retention_skip_block and 'else:' in line:
                in_retention_skip_block = False
                found_convert_in_else = True
            elif in_retention_skip_block and 'convert_if_needed' in line:
                self.fail(f"convert_if_needed called inside VIDEO_RETENTION skip block at line {i+1}")

        self.assertTrue(found_convert_in_else, "Should have else branch after video retention skip")

    def test_no_video_retention_param_in_convert_call(self):
        """convert_if_needed should NOT receive a video_retention parameter."""
        # Ensure the old video_retention parameter isn't being passed
        self.assertNotIn('video_retention=VIDEO_RETENTION', self.content)


class TestAudioConversionNotModified(unittest.TestCase):
    """Verify audio_conversion.py was fully reverted (no video_retention parameter)."""

    @classmethod
    def setUpClass(cls):
        with open(os.path.join(PROJECT_ROOT, 'src/utils/audio_conversion.py'), 'r') as f:
            cls.content = f.read()

    def test_no_video_retention_parameter(self):
        """convert_if_needed should not have a video_retention parameter."""
        self.assertNotIn('video_retention', self.content)

    def test_no_should_delete_original(self):
        """No should_delete_original variable should exist."""
        self.assertNotIn('should_delete_original', self.content)


class TestSendFileConditional(unittest.TestCase):
    """Test that send_file calls use conditional=True for range request support."""

    def _read_file(self, rel_path):
        with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
            return f.read()

    def test_recordings_streaming_has_conditional(self):
        """Streaming send_file in recordings.py has conditional=True."""
        content = self._read_file('src/api/recordings.py')
        # Find the non-download send_file call
        self.assertIn('send_file(recording.audio_path, conditional=True)', content)

    def test_recordings_download_has_conditional(self):
        """Download send_file in recordings.py has conditional=True."""
        content = self._read_file('src/api/recordings.py')
        self.assertIn('as_attachment=True, download_name=filename, conditional=True', content)

    def test_shares_has_conditional(self):
        """send_file in shares.py has conditional=True."""
        content = self._read_file('src/api/shares.py')
        self.assertIn('send_file(recording.audio_path, conditional=True)', content)


class TestFrontendTemplates(unittest.TestCase):
    """Test that frontend templates correctly switch between video and audio."""

    TEMPLATE_FILES = [
        'templates/components/detail/desktop-right-panel.html',
        'templates/components/detail/audio-player.html',
        'templates/modals/speaker-modal.html',
        'templates/share.html',
    ]

    def _read_template(self, rel_path):
        with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
            return f.read()

    def test_all_templates_use_dynamic_component(self):
        """All player templates use <component :is> for video/audio switching."""
        for tmpl in self.TEMPLATE_FILES:
            content = self._read_template(tmpl)
            self.assertIn("<component :is=", content, f"Missing dynamic component in {tmpl}")
            self.assertIn("startsWith('video/')", content, f"Missing video/ check in {tmpl}")
            self.assertIn("</component>", content, f"Missing </component> in {tmpl}")

    def test_no_bare_audio_elements_in_main_players(self):
        """Main player templates should not have bare <audio> elements (replaced by the component)."""
        for tmpl in self.TEMPLATE_FILES:
            content = self._read_template(tmpl)
            # Count <audio and <component :is occurrences
            audio_count = content.count('<audio ')
            component_count = content.count('<component :is=')

            # Each template should have <component :is> but no bare <audio> for the main player
            self.assertGreater(component_count, 0, f"No <component :is> in {tmpl}")
            # Desktop right panel and audio player should have 0 bare audio tags
            if 'desktop-right-panel' in tmpl or 'audio-player' in tmpl:
                self.assertEqual(audio_count, 0, f"Unexpected bare <audio> in {tmpl}")

    def test_video_element_gets_visible_styling(self):
        """When mime_type is video/, the element should be visible (not hidden)."""
        for tmpl in self.TEMPLATE_FILES:
            content = self._read_template(tmpl)
            # Should have a conditional class that shows video and hides audio
            self.assertIn("'w-full rounded-lg", content, f"Missing video styling in {tmpl}")
            self.assertIn("'hidden'", content, f"Missing hidden fallback for audio in {tmpl}")

    def test_template_div_balance(self):
        """Verify player-specific templates have balanced div tags."""
        # Only check templates we fully control (not share.html, which has a pre-existing imbalance)
        balanced_templates = [
            'templates/components/detail/desktop-right-panel.html',
            'templates/components/detail/audio-player.html',
            'templates/modals/speaker-modal.html',
        ]
        for tmpl in balanced_templates:
            content = self._read_template(tmpl)
            opens = content.count('<div')
            closes = content.count('</div>')
            self.assertEqual(opens, closes, f"Unbalanced divs in {tmpl}: {opens} opens, {closes} closes")


class TestLocalization(unittest.TestCase):
    """Test that video retention localization keys exist in all locale files."""

    LOCALE_DIR = os.path.join(PROJECT_ROOT, 'static', 'locales')

    def test_video_retained_key_in_all_locales(self):
        """upload.videoRetained key exists in all locale files."""
        locale_files = [f for f in os.listdir(self.LOCALE_DIR) if f.endswith('.json')]
        self.assertGreater(len(locale_files), 0, "No locale files found")

        for locale_file in locale_files:
            filepath = os.path.join(self.LOCALE_DIR, locale_file)
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)

            self.assertIn('upload', data, f"No 'upload' section in {locale_file}")
            self.assertIn('videoRetained', data['upload'],
                          f"Missing 'videoRetained' key in upload section of {locale_file}")
            self.assertIsInstance(data['upload']['videoRetained'], str,
                                  f"'videoRetained' should be a string in {locale_file}")
            self.assertGreater(len(data['upload']['videoRetained']), 0,
                               f"'videoRetained' is empty in {locale_file}")

    def test_locale_files_are_valid_json(self):
        """All locale files are valid JSON."""
        locale_files = [f for f in os.listdir(self.LOCALE_DIR) if f.endswith('.json')]
        for locale_file in locale_files:
            filepath = os.path.join(self.LOCALE_DIR, locale_file)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    json.load(f)
            except json.JSONDecodeError as e:
                self.fail(f"Invalid JSON in {locale_file}: {e}")


class TestVideoRetentionMatrix(unittest.TestCase):
    """
    Test the complete 2x2 matrix of (VIDEO_RETENTION x is_video) scenarios
    by analyzing the code flow statically.
    """

    def _read_file(self, rel_path):
        with open(os.path.join(PROJECT_ROOT, rel_path), 'r') as f:
            return f.read()

    def test_processing_has_both_branches(self):
        """processing.py has both VIDEO_RETENTION=True and False branches for video."""
        content = self._read_file('src/tasks/processing.py')
        # Should have if VIDEO_RETENTION: ... else: ... inside if is_video:
        self.assertIn('if VIDEO_RETENTION:', content)
        # After the retention block, should have else for the default behavior
        lines = content.split('\n')
        found_retention_if = False
        found_else_after = False
        for line in lines:
            if 'if VIDEO_RETENTION:' in line:
                found_retention_if = True
            elif found_retention_if and line.strip().startswith('else:'):
                found_else_after = True
                break
        self.assertTrue(found_else_after, "Missing else branch after VIDEO_RETENTION check in processing.py")

    def test_file_monitor_has_both_branches(self):
        """file_monitor.py has both video retention skip and normal conversion paths."""
        content = self._read_file('src/file_monitor.py')
        self.assertIn('VIDEO_RETENTION and has_video', content)
        # convert_if_needed should still exist in the else path
        self.assertIn('convert_if_needed(', content)

    def test_incognito_not_affected(self):
        """Incognito processing path should NOT reference VIDEO_RETENTION."""
        content = self._read_file('src/tasks/processing.py')
        # Find the incognito section (marked with [Incognito])
        incognito_section = content[content.find('[Incognito]'):]
        # VIDEO_RETENTION should not appear in the incognito section
        # (incognito always strips video per the plan)
        self.assertNotIn('VIDEO_RETENTION', incognito_section,
                         "VIDEO_RETENTION should not be referenced in incognito processing")

    def test_all_three_entry_points_skip_for_video_retention(self):
        """All entry points (upload, file monitor, processing) handle VIDEO_RETENTION."""
        for rel_path, marker in [
            ('src/api/recordings.py', 'VIDEO_RETENTION and has_video'),
            ('src/file_monitor.py', 'VIDEO_RETENTION and has_video'),
            ('src/tasks/processing.py', 'if VIDEO_RETENTION:'),
        ]:
            content = self._read_file(rel_path)
            self.assertIn(marker, content, f"Missing video retention guard in {rel_path}")

    def test_convert_if_needed_always_runs_on_transcription_audio(self):
        """Processing pipeline always runs convert_if_needed on audio before transcription."""
        content = self._read_file('src/tasks/processing.py')
        # The convert_if_needed call on actual_filepath happens AFTER the video
        # extraction block, regardless of the VIDEO_RETENTION setting
        video_block_pos = content.find('if is_video:')
        convert_pos = content.find('convert_if_needed(\n filepath=actual_filepath,')
        self.assertGreater(convert_pos, video_block_pos,
                           "convert_if_needed must run after video extraction block")


if __name__ == '__main__':
    unittest.main(verbosity=2)