# --- file-listing metadata captured with this source (not part of the code) ---
# Files
# dictia-public/tests/test_job_queue_race_condition.py
#
# 282 lines
# 9.3 KiB
# Python
#!/usr/bin/env python3
"""
Test script for job queue race condition fix.
This script verifies that the atomic job claiming mechanism prevents
multiple workers from claiming the same job simultaneously.
The fix uses an atomic UPDATE with WHERE clause to ensure only one
worker can claim a job, even with multiple processes/threads.
"""
import os
import sys
import threading
import time
from pathlib import Path
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
def test_atomic_job_claiming():
    """
    Test that only one worker can claim a job even with concurrent attempts.

    This simulates the race condition where multiple workers try to claim
    the same job simultaneously. The claim is an atomic
    UPDATE ... WHERE status == 'queued', so at most one concurrent worker
    can observe rowcount == 1.
    """
    print("\n=== Testing Atomic Job Claiming ===\n")
    # Import Flask app and models (deferred so importing this module does not
    # require a configured application).
    from src.app import app
    from src.database import db
    from src.models import ProcessingJob, User, Recording
    from sqlalchemy import update

    with app.app_context():
        # Use the first existing user for testing, or create a minimal test user
        test_user = User.query.first()
        if not test_user:
            test_user = User(
                username='test_race_condition_user',
                email='test_race@example.com',
                password='not_used'  # Password not needed for this test
            )
            db.session.add(test_user)
            db.session.commit()

        # Create a test recording for the job to reference
        test_recording = Recording(
            user_id=test_user.id,
            title='Test Race Condition Recording',
            audio_path='/tmp/test_audio.mp3',
            status='QUEUED'
        )
        db.session.add(test_recording)
        db.session.commit()

        # Create a test job in 'queued' status
        test_job = ProcessingJob(
            recording_id=test_recording.id,
            user_id=test_user.id,
            job_type='transcribe',
            status='queued'
        )
        db.session.add(test_job)
        db.session.commit()
        job_id = test_job.id
        recording_id = test_recording.id
        print(f"Created test job {job_id} with status 'queued'")

        # Track which threads successfully claimed the job
        successful_claims = []
        claim_lock = threading.Lock()

        def attempt_claim(worker_id):
            """Simulate a worker attempting to claim the job."""
            with app.app_context():
                try:
                    # This is the atomic claim logic from the fix: the UPDATE
                    # matches only while status is still 'queued', so exactly
                    # one concurrent worker gets rowcount == 1.
                    claim_time = datetime.utcnow()
                    result = db.session.execute(
                        update(ProcessingJob)
                        .where(
                            ProcessingJob.id == job_id,
                            ProcessingJob.status == 'queued'
                        )
                        .values(status='processing', started_at=claim_time)
                    )
                    if result.rowcount == 1:
                        db.session.commit()
                        with claim_lock:
                            successful_claims.append(worker_id)
                        return f"Worker {worker_id}: Successfully claimed job"
                    else:
                        db.session.rollback()
                        return f"Worker {worker_id}: Job already claimed (rowcount=0)"
                except Exception as e:
                    db.session.rollback()
                    return f"Worker {worker_id}: Error - {e}"

        # Spawn multiple threads to claim simultaneously
        num_workers = 10
        print(f"\nSpawning {num_workers} workers to claim job {job_id} simultaneously...")

        # Use a barrier to ensure all threads start at the same time
        barrier = threading.Barrier(num_workers)

        def worker_with_barrier(worker_id):
            barrier.wait()  # Wait for all threads to be ready
            return attempt_claim(worker_id)

        try:
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                futures = {executor.submit(worker_with_barrier, i): i for i in range(num_workers)}
                for future in as_completed(futures):
                    result = future.result()
                    print(f" {result}")

            # Verify results
            print(f"\n=== Results ===")
            print(f"Total workers: {num_workers}")
            print(f"Successful claims: {len(successful_claims)}")
            print(f"Workers that claimed: {successful_claims}")

            # Check final job status. Capture the value BEFORE cleanup deletes
            # the row: reading an attribute of a deleted, expired instance
            # would raise ObjectDeletedError.
            db.session.expire_all()
            final_job = db.session.get(ProcessingJob, job_id)
            final_status = final_job.status
            print(f"Final job status: {final_status}")
        finally:
            # Cleanup always runs — even when a worker errored or an assertion
            # below would fail — so no stale rows are left in the test DB.
            db.session.rollback()
            job = db.session.get(ProcessingJob, job_id)
            if job:
                db.session.delete(job)
            rec = db.session.get(Recording, recording_id)
            if rec:
                db.session.delete(rec)
            db.session.commit()

        # Assert only one worker claimed the job
        assert len(successful_claims) == 1, f"Expected 1 successful claim, got {len(successful_claims)}"
        assert final_status == 'processing', f"Expected status 'processing', got {final_status}"
        print("\n[PASS] Only one worker successfully claimed the job!")
        return True
def test_multiple_jobs_fair_distribution():
    """
    Test that multiple jobs are distributed fairly across workers.

    Sequential "workers" each claim at most one queued job via the atomic
    UPDATE ... WHERE status == 'queued' pattern; every job must be claimed
    exactly once and surplus workers must find nothing left to claim.
    """
    print("\n=== Testing Multiple Jobs Distribution ===\n")
    from src.app import app
    from src.database import db
    from src.models import ProcessingJob, User, Recording
    from sqlalchemy import update

    with app.app_context():
        # Use the first existing user for testing
        test_user = User.query.first()
        if not test_user:
            test_user = User(
                username='test_distribution_user',
                email='test_dist@example.com',
                password='not_used'
            )
            db.session.add(test_user)
            db.session.commit()

        # Create multiple test jobs, each attached to its own recording
        num_jobs = 5
        job_ids = []
        recording_ids = []
        try:
            for i in range(num_jobs):
                recording = Recording(
                    user_id=test_user.id,
                    title=f'Test Distribution Recording {i}',
                    audio_path=f'/tmp/test_audio_{i}.mp3',
                    status='QUEUED'
                )
                db.session.add(recording)
                db.session.commit()
                recording_ids.append(recording.id)
                job = ProcessingJob(
                    recording_id=recording.id,
                    user_id=test_user.id,
                    job_type='transcribe',
                    status='queued'
                )
                db.session.add(job)
                db.session.commit()
                job_ids.append(job.id)
            print(f"Created {num_jobs} test jobs: {job_ids}")

            # Have workers claim jobs
            claimed_jobs = []

            def claim_any_job(worker_id):
                """Atomically claim one queued job; return its id or None."""
                with app.app_context():
                    # Find a queued job
                    candidate = ProcessingJob.query.filter(
                        ProcessingJob.status == 'queued',
                        ProcessingJob.job_type == 'transcribe'
                    ).first()
                    if not candidate:
                        return None
                    # Atomic claim: succeeds only if the job is still queued
                    result = db.session.execute(
                        update(ProcessingJob)
                        .where(
                            ProcessingJob.id == candidate.id,
                            ProcessingJob.status == 'queued'
                        )
                        .values(status='processing', started_at=datetime.utcnow())
                    )
                    if result.rowcount == 1:
                        db.session.commit()
                        return candidate.id
                    else:
                        db.session.rollback()
                        return None

            # Each "worker" claims one job
            for i in range(num_jobs + 2):  # Extra attempts to ensure no double claims
                job_id = claim_any_job(i)
                if job_id:
                    claimed_jobs.append(job_id)
                    print(f" Worker {i} claimed job {job_id}")
                else:
                    print(f" Worker {i} found no available jobs")

            print(f"\nClaimed jobs: {claimed_jobs}")
            print(f"Unique jobs claimed: {len(set(claimed_jobs))}")

            # Verify no duplicates and that every job was claimed
            assert len(claimed_jobs) == len(set(claimed_jobs)), "Duplicate job claims detected!"
            assert len(claimed_jobs) == num_jobs, f"Expected {num_jobs} claims, got {len(claimed_jobs)}"
        finally:
            # Cleanup always runs, even when an assertion above fails, so a
            # failed run cannot leave stale queued jobs behind that would
            # poison claim_any_job's "first queued job" query on the next run.
            db.session.rollback()
            for job_id in job_ids:
                job = db.session.get(ProcessingJob, job_id)
                if job:
                    db.session.delete(job)
            for rec_id in recording_ids:
                rec = db.session.get(Recording, rec_id)
                if rec:
                    db.session.delete(rec)
            db.session.commit()
        print("\n[PASS] All jobs claimed exactly once!")
        return True
if __name__ == '__main__':
    # Run both race-condition tests as a standalone script.
    banner = "=" * 60
    print(banner)
    print("Job Queue Race Condition Tests")
    print(banner)
    try:
        test_atomic_job_claiming()
        test_multiple_jobs_fair_distribution()
        print("\n" + banner)
        print("All tests passed!")
        print(banner)
    except AssertionError as e:
        # A test assertion failed: report and exit non-zero for CI.
        print(f"\n[FAIL] Test failed: {e}")
        sys.exit(1)
    except Exception as e:
        # Anything else is an unexpected error: dump the traceback.
        print(f"\n[ERROR] Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)