282 lines
9.3 KiB
Python
282 lines
9.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for job queue race condition fix.

This script verifies that the atomic job claiming mechanism prevents
multiple workers from claiming the same job simultaneously.

The fix uses an atomic UPDATE with a WHERE clause to ensure only one
worker can claim a job, even with multiple processes/threads.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
# Add parent directory to path for imports
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
|
|
def test_atomic_job_claiming():
    """
    Test that only one worker can claim a job even with concurrent attempts.

    A single job is created in 'queued' status, then ``num_workers`` threads
    (released together by a barrier) each execute the same conditional UPDATE
    (``id == job_id AND status == 'queued'``). A worker counts as having
    claimed the job only when its UPDATE reports ``rowcount == 1``.

    The recording and job created here are deleted in a ``finally`` block, so
    they are removed even when a worker raises or an assertion fails.

    Returns:
        bool: True when exactly one worker claimed the job and the job ended
        up in 'processing' status.

    Raises:
        AssertionError: If the number of successful claims is not exactly one,
            or the final job status is not 'processing'.
    """
    print("\n=== Testing Atomic Job Claiming ===\n")

    # Import Flask app and models lazily so the sys.path tweak at module
    # level has already taken effect.
    from src.app import app
    from src.database import db
    from src.models import ProcessingJob, User, Recording
    from sqlalchemy import update

    with app.app_context():
        # Use the first existing user for testing, or create a minimal test user
        test_user = User.query.first()
        if not test_user:
            # NOTE(review): a user created by this fallback is not removed by
            # the cleanup below -- confirm whether that is desired.
            test_user = User(
                username='test_race_condition_user',
                email='test_race@example.com',
                password='not_used'  # Password not needed for this test
            )
            db.session.add(test_user)
            db.session.commit()

        # Ids of the rows this test creates; they stay None until the row
        # exists so the cleanup only touches what was actually inserted.
        recording_id = None
        job_id = None

        try:
            # Create a test recording
            test_recording = Recording(
                user_id=test_user.id,
                title='Test Race Condition Recording',
                audio_path='/tmp/test_audio.mp3',
                status='QUEUED'
            )
            db.session.add(test_recording)
            db.session.commit()
            recording_id = test_recording.id

            # Create a test job in 'queued' status
            test_job = ProcessingJob(
                recording_id=recording_id,
                user_id=test_user.id,
                job_type='transcribe',
                status='queued'
            )
            db.session.add(test_job)
            db.session.commit()

            job_id = test_job.id
            print(f"Created test job {job_id} with status 'queued'")

            # Track which threads successfully claimed the job
            successful_claims = []
            claim_lock = threading.Lock()

            def attempt_claim(worker_id):
                """Simulate a worker attempting to claim the job."""
                with app.app_context():
                    try:
                        # This is the atomic claim logic from the fix: the
                        # status check and the status change happen in one
                        # UPDATE statement.
                        claim_time = datetime.utcnow()
                        result = db.session.execute(
                            update(ProcessingJob)
                            .where(
                                ProcessingJob.id == job_id,
                                ProcessingJob.status == 'queued'
                            )
                            .values(status='processing', started_at=claim_time)
                        )

                        if result.rowcount == 1:
                            db.session.commit()
                            with claim_lock:
                                successful_claims.append(worker_id)
                            return f"Worker {worker_id}: Successfully claimed job"

                        db.session.rollback()
                        return f"Worker {worker_id}: Job already claimed (rowcount=0)"

                    except Exception as e:
                        db.session.rollback()
                        return f"Worker {worker_id}: Error - {e}"

            # Spawn multiple threads to claim simultaneously
            num_workers = 10
            print(f"\nSpawning {num_workers} workers to claim job {job_id} simultaneously...")

            # Use a barrier to ensure all threads start at the same time
            barrier = threading.Barrier(num_workers)

            def worker_with_barrier(worker_id):
                barrier.wait()  # Wait for all threads to be ready
                return attempt_claim(worker_id)

            # max_workers equals the barrier size, so every submitted task
            # gets its own thread and the barrier can be released.
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                futures = {
                    executor.submit(worker_with_barrier, i): i
                    for i in range(num_workers)
                }

                for future in as_completed(futures):
                    result = future.result()
                    print(f"  {result}")

            # Verify results
            print("\n=== Results ===")
            print(f"Total workers: {num_workers}")
            print(f"Successful claims: {len(successful_claims)}")
            print(f"Workers that claimed: {successful_claims}")

            # Check final job status. Copy it into a plain local so the
            # assertion below does not read an attribute of a deleted object.
            db.session.expire_all()
            final_job = db.session.get(ProcessingJob, job_id)
            final_status = final_job.status if final_job is not None else None
            print(f"Final job status: {final_status}")

            # Assert only one worker claimed the job
            assert len(successful_claims) == 1, f"Expected 1 successful claim, got {len(successful_claims)}"
            assert final_status == 'processing', f"Expected status 'processing', got {final_status}"

        finally:
            # Cleanup. Drop any half-finished transaction first, then remove
            # only the rows this test inserted.
            db.session.rollback()
            if job_id is not None:
                leftover_job = db.session.get(ProcessingJob, job_id)
                if leftover_job is not None:
                    db.session.delete(leftover_job)
                    # Flush so the job row is deleted before the recording
                    # it references.
                    db.session.flush()
            if recording_id is not None:
                leftover_recording = db.session.get(Recording, recording_id)
                if leftover_recording is not None:
                    db.session.delete(leftover_recording)
            db.session.commit()

        print("\n[PASS] Only one worker successfully claimed the job!")
        return True
|
|
|
|
|
|
def test_multiple_jobs_fair_distribution():
    """
    Test that a batch of queued jobs is claimed exactly once each.

    ``num_jobs`` recordings and jobs are created, then ``num_jobs + 2``
    sequential "workers" each try to claim one queued job using a
    select-then-conditional-UPDATE. The extra attempts check that no job can
    be claimed a second time once all of them are taken.

    Candidate selection is limited to the jobs created by this test, so
    queued jobs that already exist in the database are neither claimed nor
    counted. All rows created here are deleted in a ``finally`` block, so a
    failing assertion does not leave test data behind.

    Returns:
        bool: True when every created job was claimed exactly once.

    Raises:
        AssertionError: If a job was claimed more than once, or the number of
            claimed jobs differs from ``num_jobs``.
    """
    print("\n=== Testing Multiple Jobs Distribution ===\n")

    from src.app import app
    from src.database import db
    from src.models import ProcessingJob, User, Recording
    from sqlalchemy import update

    with app.app_context():
        # Use the first existing user for testing
        test_user = User.query.first()
        if not test_user:
            # NOTE(review): a user created by this fallback is not removed by
            # the cleanup below -- confirm whether that is desired.
            test_user = User(
                username='test_distribution_user',
                email='test_dist@example.com',
                password='not_used'
            )
            db.session.add(test_user)
            db.session.commit()

        # Create multiple test jobs
        num_jobs = 5
        job_ids = []
        recording_ids = []

        try:
            for i in range(num_jobs):
                recording = Recording(
                    user_id=test_user.id,
                    title=f'Test Distribution Recording {i}',
                    audio_path=f'/tmp/test_audio_{i}.mp3',
                    status='QUEUED'
                )
                db.session.add(recording)
                db.session.commit()
                recording_ids.append(recording.id)

                job = ProcessingJob(
                    recording_id=recording.id,
                    user_id=test_user.id,
                    job_type='transcribe',
                    status='queued'
                )
                db.session.add(job)
                db.session.commit()
                job_ids.append(job.id)

            print(f"Created {num_jobs} test jobs: {job_ids}")

            # Have workers claim jobs
            claimed_jobs = []

            def claim_any_job(worker_id):
                """Claim one still-queued test job; return its id, or None."""
                with app.app_context():
                    # Find a queued job, restricted to this test's own jobs
                    # so unrelated queued jobs in the database are untouched.
                    candidate = ProcessingJob.query.filter(
                        ProcessingJob.id.in_(job_ids),
                        ProcessingJob.status == 'queued',
                        ProcessingJob.job_type == 'transcribe'
                    ).first()

                    if not candidate:
                        return None

                    # Read the id before the UPDATE/commit so the return
                    # value does not depend on reloading the instance.
                    candidate_id = candidate.id

                    # Atomic claim
                    result = db.session.execute(
                        update(ProcessingJob)
                        .where(
                            ProcessingJob.id == candidate_id,
                            ProcessingJob.status == 'queued'
                        )
                        .values(status='processing', started_at=datetime.utcnow())
                    )

                    if result.rowcount == 1:
                        db.session.commit()
                        return candidate_id

                    db.session.rollback()
                    return None

            # Each "worker" claims one job
            for i in range(num_jobs + 2):  # Extra attempts to ensure no double claims
                claimed_id = claim_any_job(i)
                if claimed_id:
                    claimed_jobs.append(claimed_id)
                    print(f"  Worker {i} claimed job {claimed_id}")
                else:
                    print(f"  Worker {i} found no available jobs")

            print(f"\nClaimed jobs: {claimed_jobs}")
            print(f"Unique jobs claimed: {len(set(claimed_jobs))}")

            # Verify no duplicates
            assert len(claimed_jobs) == len(set(claimed_jobs)), "Duplicate job claims detected!"
            assert len(claimed_jobs) == num_jobs, f"Expected {num_jobs} claims, got {len(claimed_jobs)}"

        finally:
            # Cleanup. Drop any half-finished transaction first, then delete
            # whatever rows were actually created.
            db.session.rollback()
            for created_job_id in job_ids:
                job = db.session.get(ProcessingJob, created_job_id)
                if job:
                    db.session.delete(job)
            for rec_id in recording_ids:
                rec = db.session.get(Recording, rec_id)
                if rec:
                    db.session.delete(rec)
            db.session.commit()

        print("\n[PASS] All jobs claimed exactly once!")
        return True
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run both tests in order and translate any failure
    # into a non-zero exit status.
    banner = "=" * 60

    print(banner)
    print("Job Queue Race Condition Tests")
    print(banner)

    try:
        # The first failing test aborts the run; later tests are skipped.
        for run_test in (test_atomic_job_claiming,
                         test_multiple_jobs_fair_distribution):
            run_test()

        print("\n" + banner)
        print("All tests passed!")
        print(banner)

    except AssertionError as failure:
        # A test assertion did not hold.
        print(f"\n[FAIL] Test failed: {failure}")
        sys.exit(1)
    except Exception as unexpected:
        # Anything else (import errors, database errors, ...) gets a traceback.
        print(f"\n[ERROR] Unexpected error: {unexpected}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|