Initial release: DictIA v0.8.14-alpha (fork de Speakr, AGPL-3.0)
This commit is contained in:
281
tests/test_job_queue_race_condition.py
Normal file
281
tests/test_job_queue_race_condition.py
Normal file
@@ -0,0 +1,281 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for job queue race condition fix.
|
||||
|
||||
This script verifies that the atomic job claiming mechanism prevents
|
||||
multiple workers from claiming the same job simultaneously.
|
||||
|
||||
The fix uses an atomic UPDATE with WHERE clause to ensure only one
|
||||
worker can claim a job, even with multiple processes/threads.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
def test_atomic_job_claiming():
|
||||
"""
|
||||
Test that only one worker can claim a job even with concurrent attempts.
|
||||
|
||||
This simulates the race condition where multiple workers try to claim
|
||||
the same job simultaneously.
|
||||
"""
|
||||
print("\n=== Testing Atomic Job Claiming ===\n")
|
||||
|
||||
# Import Flask app and models
|
||||
from src.app import app
|
||||
from src.database import db
|
||||
from src.models import ProcessingJob, User, Recording
|
||||
from sqlalchemy import update
|
||||
|
||||
with app.app_context():
|
||||
# Use the first existing user for testing, or create a minimal test user
|
||||
test_user = User.query.first()
|
||||
if not test_user:
|
||||
test_user = User(
|
||||
username='test_race_condition_user',
|
||||
email='test_race@example.com',
|
||||
password='not_used' # Password not needed for this test
|
||||
)
|
||||
db.session.add(test_user)
|
||||
db.session.commit()
|
||||
|
||||
# Create a test recording
|
||||
test_recording = Recording(
|
||||
user_id=test_user.id,
|
||||
title='Test Race Condition Recording',
|
||||
audio_path='/tmp/test_audio.mp3',
|
||||
status='QUEUED'
|
||||
)
|
||||
db.session.add(test_recording)
|
||||
db.session.commit()
|
||||
|
||||
# Create a test job in 'queued' status
|
||||
test_job = ProcessingJob(
|
||||
recording_id=test_recording.id,
|
||||
user_id=test_user.id,
|
||||
job_type='transcribe',
|
||||
status='queued'
|
||||
)
|
||||
db.session.add(test_job)
|
||||
db.session.commit()
|
||||
|
||||
job_id = test_job.id
|
||||
print(f"Created test job {job_id} with status 'queued'")
|
||||
|
||||
# Track which threads successfully claimed the job
|
||||
successful_claims = []
|
||||
claim_lock = threading.Lock()
|
||||
|
||||
def attempt_claim(worker_id):
|
||||
"""Simulate a worker attempting to claim the job."""
|
||||
with app.app_context():
|
||||
try:
|
||||
# This is the atomic claim logic from the fix
|
||||
claim_time = datetime.utcnow()
|
||||
result = db.session.execute(
|
||||
update(ProcessingJob)
|
||||
.where(
|
||||
ProcessingJob.id == job_id,
|
||||
ProcessingJob.status == 'queued'
|
||||
)
|
||||
.values(status='processing', started_at=claim_time)
|
||||
)
|
||||
|
||||
if result.rowcount == 1:
|
||||
db.session.commit()
|
||||
with claim_lock:
|
||||
successful_claims.append(worker_id)
|
||||
return f"Worker {worker_id}: Successfully claimed job"
|
||||
else:
|
||||
db.session.rollback()
|
||||
return f"Worker {worker_id}: Job already claimed (rowcount=0)"
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return f"Worker {worker_id}: Error - {e}"
|
||||
|
||||
# Spawn multiple threads to claim simultaneously
|
||||
num_workers = 10
|
||||
print(f"\nSpawning {num_workers} workers to claim job {job_id} simultaneously...")
|
||||
|
||||
# Use a barrier to ensure all threads start at the same time
|
||||
barrier = threading.Barrier(num_workers)
|
||||
|
||||
def worker_with_barrier(worker_id):
|
||||
barrier.wait() # Wait for all threads to be ready
|
||||
return attempt_claim(worker_id)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {executor.submit(worker_with_barrier, i): i for i in range(num_workers)}
|
||||
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
print(f" {result}")
|
||||
|
||||
# Verify results
|
||||
print(f"\n=== Results ===")
|
||||
print(f"Total workers: {num_workers}")
|
||||
print(f"Successful claims: {len(successful_claims)}")
|
||||
print(f"Workers that claimed: {successful_claims}")
|
||||
|
||||
# Check final job status
|
||||
db.session.expire_all()
|
||||
final_job = db.session.get(ProcessingJob, job_id)
|
||||
print(f"Final job status: {final_job.status}")
|
||||
|
||||
# Cleanup
|
||||
db.session.delete(final_job)
|
||||
db.session.delete(test_recording)
|
||||
db.session.commit()
|
||||
|
||||
# Assert only one worker claimed the job
|
||||
assert len(successful_claims) == 1, f"Expected 1 successful claim, got {len(successful_claims)}"
|
||||
assert final_job.status == 'processing', f"Expected status 'processing', got {final_job.status}"
|
||||
|
||||
print("\n[PASS] Only one worker successfully claimed the job!")
|
||||
return True
|
||||
|
||||
|
||||
def test_multiple_jobs_fair_distribution():
|
||||
"""
|
||||
Test that multiple jobs are distributed fairly across workers.
|
||||
"""
|
||||
print("\n=== Testing Multiple Jobs Distribution ===\n")
|
||||
|
||||
from src.app import app
|
||||
from src.database import db
|
||||
from src.models import ProcessingJob, User, Recording
|
||||
from sqlalchemy import update
|
||||
|
||||
with app.app_context():
|
||||
# Use the first existing user for testing
|
||||
test_user = User.query.first()
|
||||
if not test_user:
|
||||
test_user = User(
|
||||
username='test_distribution_user',
|
||||
email='test_dist@example.com',
|
||||
password='not_used'
|
||||
)
|
||||
db.session.add(test_user)
|
||||
db.session.commit()
|
||||
|
||||
# Create multiple test jobs
|
||||
num_jobs = 5
|
||||
job_ids = []
|
||||
recording_ids = []
|
||||
|
||||
for i in range(num_jobs):
|
||||
recording = Recording(
|
||||
user_id=test_user.id,
|
||||
title=f'Test Distribution Recording {i}',
|
||||
audio_path=f'/tmp/test_audio_{i}.mp3',
|
||||
status='QUEUED'
|
||||
)
|
||||
db.session.add(recording)
|
||||
db.session.commit()
|
||||
recording_ids.append(recording.id)
|
||||
|
||||
job = ProcessingJob(
|
||||
recording_id=recording.id,
|
||||
user_id=test_user.id,
|
||||
job_type='transcribe',
|
||||
status='queued'
|
||||
)
|
||||
db.session.add(job)
|
||||
db.session.commit()
|
||||
job_ids.append(job.id)
|
||||
|
||||
print(f"Created {num_jobs} test jobs: {job_ids}")
|
||||
|
||||
# Have workers claim jobs
|
||||
claimed_jobs = []
|
||||
|
||||
def claim_any_job(worker_id):
|
||||
with app.app_context():
|
||||
# Find a queued job
|
||||
candidate = ProcessingJob.query.filter(
|
||||
ProcessingJob.status == 'queued',
|
||||
ProcessingJob.job_type == 'transcribe'
|
||||
).first()
|
||||
|
||||
if not candidate:
|
||||
return None
|
||||
|
||||
# Atomic claim
|
||||
result = db.session.execute(
|
||||
update(ProcessingJob)
|
||||
.where(
|
||||
ProcessingJob.id == candidate.id,
|
||||
ProcessingJob.status == 'queued'
|
||||
)
|
||||
.values(status='processing', started_at=datetime.utcnow())
|
||||
)
|
||||
|
||||
if result.rowcount == 1:
|
||||
db.session.commit()
|
||||
return candidate.id
|
||||
else:
|
||||
db.session.rollback()
|
||||
return None
|
||||
|
||||
# Each "worker" claims one job
|
||||
for i in range(num_jobs + 2): # Extra attempts to ensure no double claims
|
||||
job_id = claim_any_job(i)
|
||||
if job_id:
|
||||
claimed_jobs.append(job_id)
|
||||
print(f" Worker {i} claimed job {job_id}")
|
||||
else:
|
||||
print(f" Worker {i} found no available jobs")
|
||||
|
||||
print(f"\nClaimed jobs: {claimed_jobs}")
|
||||
print(f"Unique jobs claimed: {len(set(claimed_jobs))}")
|
||||
|
||||
# Verify no duplicates
|
||||
assert len(claimed_jobs) == len(set(claimed_jobs)), "Duplicate job claims detected!"
|
||||
assert len(claimed_jobs) == num_jobs, f"Expected {num_jobs} claims, got {len(claimed_jobs)}"
|
||||
|
||||
# Cleanup
|
||||
for job_id in job_ids:
|
||||
job = db.session.get(ProcessingJob, job_id)
|
||||
if job:
|
||||
db.session.delete(job)
|
||||
for rec_id in recording_ids:
|
||||
rec = db.session.get(Recording, rec_id)
|
||||
if rec:
|
||||
db.session.delete(rec)
|
||||
db.session.commit()
|
||||
|
||||
print("\n[PASS] All jobs claimed exactly once!")
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("=" * 60)
|
||||
print("Job Queue Race Condition Tests")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
test_atomic_job_claiming()
|
||||
test_multiple_jobs_fair_distribution()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All tests passed!")
|
||||
print("=" * 60)
|
||||
|
||||
except AssertionError as e:
|
||||
print(f"\n[FAIL] Test failed: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n[ERROR] Unexpected error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user