Initial release: DictIA v0.8.14-alpha (fork de Speakr, AGPL-3.0)
This commit is contained in:
5
deployment/asr-proxy/.gitignore
vendored
Normal file
5
deployment/asr-proxy/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
gcp-credentials.json
|
||||
usage-stats.json
|
||||
venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
22
deployment/asr-proxy/asr-proxy.service
Normal file
22
deployment/asr-proxy/asr-proxy.service
Normal file
@@ -0,0 +1,22 @@
|
||||
# TEMPLATE — Do not copy directly into /etc/systemd/system/.
# The ${ASR_PROXY_USER} and ${ASR_PROXY_DIR} variables are placeholders.
# The real service file is generated by setup.sh (via a bash heredoc) with the
# resolved values of $SERVICE_USER and $INSTALL_DIR.
# Usage: sudo bash setup.sh (installs and enables the service automatically)

[Unit]
Description=DictIA ASR Proxy - GPU Auto-Start/Stop for WhisperX
After=network.target

[Service]
Type=simple
User=${ASR_PROXY_USER}
Restart=always
RestartSec=10
WorkingDirectory=${ASR_PROXY_DIR}
ExecStart=${ASR_PROXY_DIR}/venv/bin/python proxy.py
Environment=GOOGLE_APPLICATION_CREDENTIALS=${ASR_PROXY_DIR}/gcp-credentials.json
Environment=STATS_FILE=${ASR_PROXY_DIR}/usage-stats.json

[Install]
WantedBy=multi-user.target
|
||||
1534
deployment/asr-proxy/dashboard.html
Normal file
1534
deployment/asr-proxy/dashboard.html
Normal file
File diff suppressed because it is too large
Load Diff
741
deployment/asr-proxy/proxy.py
Normal file
741
deployment/asr-proxy/proxy.py
Normal file
@@ -0,0 +1,741 @@
|
||||
"""DictIA ASR Proxy - Auto-start/stop GCP GPU for WhisperX + Ollama.
|
||||
|
||||
Uses Google Cloud Compute REST API directly (no gcloud CLI needed).
|
||||
Proxies both ASR (WhisperX) and LLM (Ollama) requests.
|
||||
Multi-zone fallback across Canada (Montreal + Toronto).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import jwt as pyjwt
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, Response
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger("asr-proxy")

# Config — paths relative to this script's directory by default
SCRIPT_DIR = Path(__file__).parent
GCP_PROJECT = os.getenv("GCP_PROJECT", "speakr-gpu")
# Ports exposed by the services running on the GPU VM.
WHISPERX_PORT = int(os.getenv("WHISPERX_PORT", "9000"))
OLLAMA_PORT = int(os.getenv("OLLAMA_PORT", "11434"))
# Seconds of inactivity before the idle loop stops the GPU instance.
IDLE_TIMEOUT = int(os.getenv("IDLE_TIMEOUT", "300"))
CREDS_FILE = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", str(SCRIPT_DIR / "gcp-credentials.json"))
STATS_FILE = os.getenv("STATS_FILE", str(SCRIPT_DIR / "usage-stats.json"))
# Hard monthly cap on GPU runtime; past it, new GPU starts are refused.
MONTHLY_LIMIT_HOURS = float(os.getenv("MONTHLY_LIMIT_HOURS", "30"))
# Real GCP cost per GPU-hour (g2-standard-4 + L4): GPU ($0.837) + vCPU ($0.151) + RAM ($0.069)
GPU_COST_PER_HOUR = float(os.getenv("GPU_COST_PER_HOUR", "1.06"))
# Fixed monthly costs: SSD disks ($5.66) + snapshots ($4.19) ≈ $9.85/month
FIXED_MONTHLY_COST = float(os.getenv("FIXED_MONTHLY_COST", "9.85"))
# Boot snapshot used when an instance must be (re)created in a zone.
SNAPSHOT_NAME = "whisperx-gpu-snapshot"
HEALTH_POLL_INTERVAL = 5  # seconds between service health probes
BOOT_TIMEOUT = 300  # max seconds to wait for WhisperX/Ollama to become healthy

# Zone fallback order — Canada only, Montreal first
ZONE_FALLBACKS = [
    {
        "zone": "northamerica-northeast1-b",
        "instance": "whisperx-gpu-mtl1",
        "machine_type": "g2-standard-4",
        "accelerator": "nvidia-l4",
        "accel_count": 1,
        "label": "Montreal-b (L4)",
    },
    {
        "zone": "northamerica-northeast1-c",
        "instance": "whisperx-gpu-mtl2",
        "machine_type": "n1-standard-4",
        "accelerator": "nvidia-tesla-t4",
        "accel_count": 1,
        "label": "Montreal-c (T4)",
    },
    {
        "zone": "northamerica-northeast2-a",
        "instance": "whisperx-gpu-tor1",
        "machine_type": "g2-standard-4",
        "accelerator": "nvidia-l4",
        "accel_count": 1,
        "label": "Toronto-a (L4)",
    },
    {
        "zone": "northamerica-northeast2-b",
        "instance": "whisperx-gpu",
        "machine_type": "g2-standard-4",
        "accelerator": "nvidia-l4",
        "accel_count": 1,
        "label": "Toronto-b (L4)",
    },
]

# Startup script injected as instance metadata: brings the ASR container and
# Ollama back up after a (re)start.
STARTUP_SCRIPT = """#!/bin/bash
systemctl start docker
sleep 5
docker start whisperx-asr 2>/dev/null || true
systemctl start ollama 2>/dev/null || true
"""

app = FastAPI(title="DictIA ASR Proxy")

# State
last_request_time = 0.0  # wall-clock time of the most recent proxied request
active_requests = 0  # number of proxied requests currently in flight
gpu_ip: str | None = None  # external IP of the running GPU instance
active_zone: dict | None = None  # ZONE_FALLBACKS entry currently in use
shutdown_task: asyncio.Task | None = None  # handle on idle_shutdown_loop

# Request history tracking (in-memory, last 20 requests)
request_history: list[dict] = []
MAX_HISTORY = 20

# Zone status tracking
zone_status: dict[str, dict] = {}

# Startup lock and failure cooldown
_startup_lock: asyncio.Lock | None = None  # created in on_startup (needs the event loop)
_last_failure_time: float = 0
FAILURE_COOLDOWN = 180  # seconds to back off after all zones failed

# OAuth2 token cache
_access_token: str | None = None
_token_expiry: float = 0
|
||||
|
||||
|
||||
# --- Usage Stats ---
|
||||
|
||||
def load_stats() -> dict:
    """Read the usage-stats JSON file, falling back to a fresh record.

    Returns a dict with keys: gpu_seconds, month, requests, last_start.
    """
    fresh = {"gpu_seconds": 0, "month": time.strftime("%Y-%m"), "requests": 0, "last_start": 0}
    try:
        raw = Path(STATS_FILE).read_text()
        return json.loads(raw)
    except (FileNotFoundError, json.JSONDecodeError):
        # Missing or corrupt stats file: start over with an empty record.
        return fresh
|
||||
|
||||
|
||||
def save_stats(stats: dict):
    """Persist *stats* to STATS_FILE as pretty-printed JSON."""
    with open(STATS_FILE, "w") as fh:
        fh.write(json.dumps(stats, indent=2))
|
||||
|
||||
|
||||
def track_gpu_time():
    """Fold the currently open GPU session into the monthly usage stats.

    Resets the record when the calendar month has rolled over, then adds
    the elapsed time since last_start (if a session is open) to
    gpu_seconds and closes the session.
    """
    current_month = time.strftime("%Y-%m")
    stats = load_stats()
    if stats.get("month") != current_month:
        # New month: discard last month's counters entirely.
        stats = {"gpu_seconds": 0, "month": current_month, "requests": 0, "last_start": 0}
    session_start = stats.get("last_start", 0)
    if session_start > 0:
        stats["gpu_seconds"] += time.time() - session_start
        stats["last_start"] = 0
    save_stats(stats)
|
||||
|
||||
|
||||
def check_budget() -> tuple[bool, float]:
    """Return (within_budget, gpu_hours_used_this_month)."""
    stats = load_stats()
    if stats.get("month") != time.strftime("%Y-%m"):
        # Stats belong to a previous month; nothing consumed yet this month.
        return True, 0.0
    used = stats.get("gpu_seconds", 0) / 3600
    return used < MONTHLY_LIMIT_HOURS, used
|
||||
|
||||
|
||||
# --- GCP Auth ---
|
||||
|
||||
async def get_access_token() -> str:
    """Return a cached GCP OAuth2 access token, refreshing it when near expiry.

    Supports the two credential shapes that may live in CREDS_FILE:
      - "service_account": a self-signed RS256 JWT exchanged via the
        jwt-bearer grant against Google's token endpoint.
      - anything else (default "authorized_user"): the standard
        refresh-token grant using client id/secret.
    Raises httpx.HTTPStatusError when the token endpoint rejects the request.
    """
    global _access_token, _token_expiry
    # Reuse the cached token unless it expires within the next 60 seconds.
    if _access_token and time.time() < _token_expiry - 60:
        return _access_token
    with open(CREDS_FILE) as f:
        creds = json.load(f)
    cred_type = creds.get("type", "authorized_user")
    async with httpx.AsyncClient() as client:
        if cred_type == "service_account":
            now = int(time.time())
            # Self-signed JWT assertion per Google's service-account flow.
            payload = {
                "iss": creds["client_email"],
                "scope": "https://www.googleapis.com/auth/compute",
                "aud": "https://oauth2.googleapis.com/token",
                "iat": now,
                "exp": now + 3600,
            }
            signed = pyjwt.encode(payload, creds["private_key"], algorithm="RS256")
            resp = await client.post(
                "https://oauth2.googleapis.com/token",
                data={
                    "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                    "assertion": signed,
                },
            )
        else:
            # Authorized-user credentials: refresh-token grant.
            resp = await client.post(
                "https://oauth2.googleapis.com/token",
                data={
                    "client_id": creds["client_id"],
                    "client_secret": creds["client_secret"],
                    "refresh_token": creds["refresh_token"],
                    "grant_type": "refresh_token",
                },
            )
    resp.raise_for_status()
    data = resp.json()
    _access_token = data["access_token"]
    # Fall back to the standard 1-hour lifetime if expires_in is absent.
    _token_expiry = time.time() + data.get("expires_in", 3600)
    log.info(f"Refreshed GCP access token ({cred_type})")
    return _access_token
|
||||
|
||||
|
||||
# --- GCP Compute API ---
|
||||
|
||||
COMPUTE_BASE = "https://compute.googleapis.com/compute/v1"
|
||||
|
||||
|
||||
async def gcp_api(method: str, url: str, **kwargs) -> httpx.Response:
    """Issue an authenticated request against the GCP Compute REST API."""
    auth = {"Authorization": f"Bearer {await get_access_token()}"}
    async with httpx.AsyncClient(timeout=60) as client:
        return await client.request(method, url, headers=auth, **kwargs)
|
||||
|
||||
|
||||
async def get_instance_info(zone: str, instance: str) -> dict | None:
    """Fetch instance metadata; None when the instance is missing or on API error."""
    resp = await gcp_api(
        "GET",
        f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}",
    )
    if resp.status_code == 404:
        return None
    if resp.status_code >= 400:
        log.error(f"GCP API error {resp.status_code}: {resp.text}")
        return None
    return resp.json()
|
||||
|
||||
|
||||
def extract_ip(instance_data: dict) -> str:
    """Return the external NAT IP of the instance's first network interface.

    Only the first interface and its first access config are consulted,
    matching how instances are created here. Returns '' when absent.
    """
    for interface in instance_data.get("networkInterfaces", [])[:1]:
        for access_cfg in interface.get("accessConfigs", [])[:1]:
            return access_cfg.get("natIP", "")
    return ""
|
||||
|
||||
|
||||
async def start_instance_in_zone(zone: str, instance: str) -> bool:
    """Request an instance start; True when GCP accepted the operation."""
    url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}/start"
    resp = await gcp_api("POST", url)
    accepted = resp.status_code < 400
    if accepted:
        log.info(f"Start requested: {instance} in {zone}")
    else:
        log.warning(f"Failed to start {instance} in {zone}: {resp.status_code} {resp.text}")
    return accepted
|
||||
|
||||
|
||||
async def stop_instance_in_zone(zone: str, instance: str):
    """Request an instance stop (best-effort; failures are only logged)."""
    url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}/stop"
    resp = await gcp_api("POST", url)
    if resp.status_code >= 400:
        log.error(f"Failed to stop {instance} in {zone}: {resp.status_code}")
    else:
        log.info(f"Stop requested: {instance} in {zone}")
|
||||
|
||||
|
||||
async def create_instance_from_snapshot(config: dict) -> bool:
    """Create a GPU instance in *config*'s zone from the shared boot snapshot.

    Returns True when the insert was accepted by GCP; False on capacity,
    quota, or other API errors (logged, never raised) so the caller can
    fall through to the next zone in ZONE_FALLBACKS.
    """
    zone = config["zone"]
    instance = config["instance"]
    machine = config["machine_type"]
    accel = config["accelerator"]
    accel_count = config["accel_count"]

    log.info(f"Creating {instance} in {zone} from snapshot...")

    body = {
        "name": instance,
        "machineType": f"zones/{zone}/machineTypes/{machine}",
        # Boot disk restored from the shared snapshot; deleted with the VM.
        "disks": [{
            "boot": True,
            "autoDelete": True,
            "initializeParams": {
                "diskSizeGb": "50",
                "diskType": f"zones/{zone}/diskTypes/pd-ssd",
                "sourceSnapshot": f"global/snapshots/{SNAPSHOT_NAME}",
            },
        }],
        "networkInterfaces": [{
            "network": "global/networks/default",
            # External NAT so the proxy can reach the VM by public IP.
            "accessConfigs": [{"type": "ONE_TO_ONE_NAT", "name": "External NAT"}],
        }],
        "guestAccelerators": [{
            "acceleratorType": f"zones/{zone}/acceleratorTypes/{accel}",
            "acceleratorCount": accel_count,
        }],
        # GPU VMs cannot live-migrate: terminate on host maintenance and do
        # not auto-restart (the proxy restarts instances on demand).
        "scheduling": {
            "onHostMaintenance": "TERMINATE",
            "automaticRestart": False,
        },
        "tags": {"items": ["whisperx-gpu"]},
        "metadata": {
            # Brings docker/whisperx/ollama back up on every boot.
            "items": [{"key": "startup-script", "value": STARTUP_SCRIPT}],
        },
    }

    url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances"
    resp = await gcp_api("POST", url, json=body)

    if resp.status_code < 400:
        log.info(f"Created {instance} in {zone}")
        return True

    # Classify the failure for clearer logs; all paths return False.
    error_text = resp.text
    if "ZONE_RESOURCE_POOL_EXHAUSTED" in error_text:
        log.warning(f"No capacity in {zone} -- skipping")
    elif "QUOTA" in error_text.upper():
        log.warning(f"Quota exceeded for {zone}: {error_text[:200]}")
    else:
        log.error(f"Failed to create in {zone}: {resp.status_code} {error_text[:200]}")
    return False
|
||||
|
||||
|
||||
# --- Core Logic ---
|
||||
|
||||
async def wait_for_running(zone: str, instance: str, timeout: int = 120, grace: int = 15) -> bool:
    """Poll every 5s until the instance reports RUNNING.

    Args:
        timeout: total polling budget in seconds (polled timeout // 5 times).
        grace: seconds during which TERMINATED/STOPPED is tolerated -- the
            instance may not have transitioned yet right after start/create.

    Returns False when the instance 404s twice (deleted for lack of
    capacity), is STOPPING, stays stopped past the grace period, or the
    polling budget is exhausted.
    """
    gone_count = 0
    start_time = time.time()
    for _ in range(timeout // 5):
        info = await get_instance_info(zone, instance)
        if info and info.get("status") == "RUNNING":
            return True
        status = info.get("status", "UNKNOWN") if info else "GONE"
        elapsed = time.time() - start_time
        if status == "GONE":
            # Require a second 404 before giving up -- the API can lag
            # briefly after an instance insert.
            gone_count += 1
            if gone_count >= 2:
                log.warning(f"{instance} in {zone}: instance disappeared (no capacity)")
                return False
        if status in ("STOPPING",):
            log.warning(f"{instance} in {zone}: status {status} (no capacity)")
            return False
        if status in ("TERMINATED", "STOPPED") and elapsed > grace:
            log.warning(f"{instance} in {zone}: status {status} after {elapsed:.0f}s (no capacity)")
            return False
        await asyncio.sleep(5)
    return False
|
||||
|
||||
|
||||
async def delete_instance(zone: str, instance: str):
    """Delete an instance (best-effort) so its GPU quota is released."""
    url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}"
    resp = await gcp_api("DELETE", url)
    if resp.status_code == 404:
        # Already gone -- nothing to do.
        return
    if resp.status_code < 400:
        log.info(f"Deleted {instance} in {zone} to free quota")
    else:
        log.warning(f"Failed to delete {instance} in {zone}: {resp.status_code}")
|
||||
|
||||
|
||||
async def ensure_gpu_running() -> str:
    """Ensure one GPU instance is RUNNING and return its external IP.

    Order of operations:
      1. Honor the post-failure cooldown (sleep out the remainder).
      2. Under the startup lock: enforce the monthly budget, fast-path the
         currently active zone if it is still RUNNING, otherwise walk
         ZONE_FALLBACKS creating/starting instances until one comes up.

    Raises RuntimeError when the budget is exhausted or no zone yields a
    running instance with an external IP.
    """
    global gpu_ip, active_zone, _last_failure_time

    # After a full-fallback failure, wait out the cooldown instead of
    # hammering GCP again immediately.
    if _last_failure_time > 0:
        remaining = FAILURE_COOLDOWN - (time.time() - _last_failure_time)
        if remaining > 0:
            log.info(f"GPU cooldown active ({int(remaining)}s remaining), waiting...")
            await asyncio.sleep(remaining)
        _last_failure_time = 0

    # Serialize concurrent startup attempts -- only one request drives the
    # create/start sequence; the rest wait and hit the fast path below.
    async with _startup_lock:
        ok, hours = check_budget()
        if not ok:
            raise RuntimeError(f"Monthly GPU limit reached ({hours:.1f}h / {MONTHLY_LIMIT_HOURS}h)")

        # Fast path: the previously used zone is still up.
        if active_zone:
            info = await get_instance_info(active_zone["zone"], active_zone["instance"])
            if info and info.get("status") == "RUNNING":
                gpu_ip = extract_ip(info)
                if gpu_ip:
                    return gpu_ip

        errors = []

        for config in ZONE_FALLBACKS:
            zone = config["zone"]
            instance = config["instance"]
            label = config["label"]

            log.info(f"Trying {label}...")
            info = await get_instance_info(zone, instance)

            if info is None:
                # Instance does not exist in this zone: create from snapshot.
                created = await create_instance_from_snapshot(config)
                if not created:
                    zone_status[label] = {
                        "status": "no_capacity",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": "no capacity",
                    }
                    errors.append(f"{label}: no capacity")
                    continue
                if not await wait_for_running(zone, instance, grace=30):
                    zone_status[label] = {
                        "status": "error",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": "created but failed to start",
                    }
                    errors.append(f"{label}: created but failed to start")
                    # Delete the half-started VM so its quota is freed.
                    await delete_instance(zone, instance)
                    await asyncio.sleep(3)
                    continue
            else:
                status = info.get("status", "UNKNOWN")

                if status == "RUNNING":
                    pass  # Already up; fall through to the IP check below.
                elif status in ("TERMINATED", "STOPPED"):
                    zone_status[label] = {
                        "status": "starting",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": None,
                    }
                    started = await start_instance_in_zone(zone, instance)
                    if not started:
                        zone_status[label] = {
                            "status": "error",
                            "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                            "last_error": "start rejected",
                        }
                        errors.append(f"{label}: start rejected")
                        continue
                    if not await wait_for_running(zone, instance, grace=20):
                        zone_status[label] = {
                            "status": "error",
                            "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                            "last_error": "didn't reach RUNNING",
                        }
                        errors.append(f"{label}: didn't reach RUNNING")
                        continue
                elif status in ("STAGING", "PROVISIONING"):
                    # Someone/something else is already booting it; just wait.
                    zone_status[label] = {
                        "status": "starting",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": None,
                    }
                    if not await wait_for_running(zone, instance):
                        zone_status[label] = {
                            "status": "error",
                            "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                            "last_error": f"stuck in {status}",
                        }
                        errors.append(f"{label}: stuck in {status}")
                        continue
                elif status == "STOPPING":
                    # Can't start a STOPPING instance; delete it to free
                    # quota and move on to the next zone.
                    log.info(f"{label}: STOPPING, deleting to free quota")
                    await delete_instance(zone, instance)
                    await asyncio.sleep(3)
                    zone_status[label] = {
                        "status": "error",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": "was STOPPING, deleted",
                    }
                    errors.append(f"{label}: was STOPPING, deleted")
                    continue

            # Re-read the instance to confirm RUNNING and pick up the IP.
            info = await get_instance_info(zone, instance)
            if info and info.get("status") == "RUNNING":
                gpu_ip = extract_ip(info)
                if gpu_ip:
                    active_zone = config
                    _last_failure_time = 0
                    zone_status[label] = {
                        "status": "running",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": None,
                    }
                    # Open a billing session and bump the request counter.
                    stats = load_stats()
                    stats["last_start"] = time.time()
                    stats["requests"] = stats.get("requests", 0) + 1
                    stats["active_zone"] = label
                    save_stats(stats)
                    log.info(f"GPU ready in {label}, IP: {gpu_ip}")
                    return gpu_ip

            zone_status[label] = {
                "status": "error",
                "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                "last_error": "running but no IP",
            }
            errors.append(f"{label}: running but no IP")

        # Every zone failed: arm the cooldown and surface the per-zone errors.
        _last_failure_time = time.time()
        raise RuntimeError(
            f"No GPU available in any Canadian zone. Tried: {'; '.join(errors)}"
        )
|
||||
|
||||
|
||||
async def ensure_gpu_ready() -> str:
    """Start the GPU (if needed) and block until WhisperX answers /health.

    Returns the instance's external IP. Raises RuntimeError when the
    service is not healthy within BOOT_TIMEOUT seconds.
    """
    ip = await ensure_gpu_running()
    url = f"http://{ip}:{WHISPERX_PORT}/health"
    log.info(f"Waiting for WhisperX at {url}...")
    async with httpx.AsyncClient(timeout=10) as client:
        for _ in range(BOOT_TIMEOUT // HEALTH_POLL_INTERVAL):
            try:
                resp = await client.get(url)
                if resp.status_code == 200:
                    log.info("WhisperX is healthy!")
                    return ip
            except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout):
                # Expected while the VM/container is still booting.
                pass
            await asyncio.sleep(HEALTH_POLL_INTERVAL)
    raise RuntimeError("WhisperX did not become healthy in time")
|
||||
|
||||
|
||||
async def ensure_ollama_ready() -> str:
    """Start the GPU (if needed) and block until Ollama answers /api/tags.

    Returns the instance's external IP. Raises RuntimeError when the
    service is not healthy within BOOT_TIMEOUT seconds.
    """
    ip = await ensure_gpu_running()
    url = f"http://{ip}:{OLLAMA_PORT}/api/tags"
    log.info(f"Waiting for Ollama at {url}...")
    async with httpx.AsyncClient(timeout=10) as client:
        for _ in range(BOOT_TIMEOUT // HEALTH_POLL_INTERVAL):
            try:
                resp = await client.get(url)
                if resp.status_code == 200:
                    log.info("Ollama is healthy!")
                    return ip
            except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout):
                # Expected while the VM/service is still booting.
                pass
            await asyncio.sleep(HEALTH_POLL_INTERVAL)
    raise RuntimeError("Ollama did not become healthy in time")
|
||||
|
||||
|
||||
async def idle_shutdown_loop():
    """Background task: stop the GPU after IDLE_TIMEOUT seconds of inactivity.

    Checks once a minute. Skips while requests are in flight or before any
    instance has been used. Stop failures are logged and retried on the
    next tick.
    """
    while True:
        await asyncio.sleep(60)
        if last_request_time == 0 or active_zone is None:
            continue
        if active_requests > 0:
            continue
        elapsed = time.time() - last_request_time
        if elapsed >= IDLE_TIMEOUT:
            try:
                zone = active_zone["zone"]
                instance = active_zone["instance"]
                label = active_zone["label"]
                # Only stop when actually RUNNING (avoid redundant API calls).
                info = await get_instance_info(zone, instance)
                if info and info.get("status") == "RUNNING":
                    log.info(f"Idle {int(elapsed)}s -- stopping {label}")
                    await stop_instance_in_zone(zone, instance)
                    # Close the billing session now that the VM is stopping.
                    track_gpu_time()
            except Exception as e:
                log.error(f"Error stopping: {e}")
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favor of lifespan handlers -- works as-is, consider migrating.
@app.on_event("startup")
async def on_startup():
    """Initialize the startup lock, warm the GCP token, launch the idle watcher."""
    global shutdown_task, _startup_lock
    # Create the lock inside the running event loop.
    _startup_lock = asyncio.Lock()
    # Fail fast at boot if the credentials are unusable.
    await get_access_token()
    shutdown_task = asyncio.create_task(idle_shutdown_loop())
    zones = ", ".join(c["label"] for c in ZONE_FALLBACKS)
    log.info(f"DictIA ASR Proxy started. Zones: [{zones}]. Idle: {IDLE_TIMEOUT}s, limit: {MONTHLY_LIMIT_HOURS}h")
|
||||
|
||||
|
||||
@app.post("/asr")
async def asr_proxy(request: Request):
    """Proxy an ASR request to WhisperX on the GPU VM, booting it if needed.

    Forwards the raw body and headers (minus host/transfer-encoding),
    records the request in the in-memory history, and maps upstream
    failures to 502 and timeouts to 504.
    """
    global last_request_time, active_requests

    body = await request.body()
    # Drop headers the upstream stack must compute itself.
    headers = {
        k: v for k, v in request.headers.items()
        if k.lower() not in ("host", "transfer-encoding")
    }

    last_request_time = time.time()
    active_requests += 1
    start_time = time.time()
    result_status = 200
    try:
        ip = await ensure_gpu_ready()
        target = f"http://{ip}:{WHISPERX_PORT}/asr"
        log.info(f"Forwarding {len(body)} bytes to {target}")
        # Transcriptions can run very long -- allow up to 2 hours.
        async with httpx.AsyncClient(timeout=httpx.Timeout(7200.0)) as client:
            resp = await client.post(target, content=body, headers=headers)
        last_request_time = time.time()
        result_status = resp.status_code
        ct = resp.headers.get("content-type", "")
        if "application/json" in ct:
            return JSONResponse(content=resp.json(), status_code=resp.status_code)
        # Fix: relay non-JSON bodies verbatim with their original content
        # type. The previous JSONResponse(content=resp.text) re-encoded
        # plain text as a quoted JSON string. This also matches how
        # llm_proxy relays upstream responses.
        return Response(
            content=resp.content,
            status_code=resp.status_code,
            media_type=ct or None,
        )
    except httpx.ReadTimeout:
        result_status = 504
        return JSONResponse({"error": "Transcription timeout (2h)"}, status_code=504)
    except Exception as e:
        result_status = 502
        log.error(f"Proxy error: {e}")
        return JSONResponse({"error": str(e)}, status_code=502)
    finally:
        active_requests -= 1
        last_request_time = time.time()
        # Record the request (newest first, capped at MAX_HISTORY).
        request_history.insert(0, {
            "time": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "type": "ASR",
            "duration_sec": round(time.time() - start_time, 1),
            "status": result_status,
            "zone": active_zone["label"] if active_zone else "none",
        })
        if len(request_history) > MAX_HISTORY:
            request_history.pop()
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Liveness/status endpoint: proxy, GPU instance, budget, and token info."""
    zone_label = active_zone["label"] if active_zone else "none"
    gpu_status = "unknown"
    if active_zone:
        try:
            info = await get_instance_info(active_zone["zone"], active_zone["instance"])
            gpu_status = info.get("status", "unknown") if info else "not_found"
        except Exception:
            # Health must never fail just because the GCP API is unreachable.
            pass
    ok, hours = check_budget()
    stats = load_stats()
    return {
        "proxy": "healthy",
        "gpu_instance": gpu_status,
        "gpu_zone": zone_label,
        "active_requests": active_requests,
        "idle_timeout": IDLE_TIMEOUT,
        "usage": {
            "month": stats.get("month"),
            "gpu_hours": round(hours, 2),
            "gpu_limit_hours": MONTHLY_LIMIT_HOURS,
            "requests_count": stats.get("requests", 0),
            "budget_ok": ok,
        },
        "gpu_ip": gpu_ip,
        "machine_type": active_zone.get("machine_type", "unknown") if active_zone else "unknown",
        "gpu_model": active_zone.get("accelerator", "unknown") if active_zone else "unknown",
        "idle_seconds": round(time.time() - last_request_time) if last_request_time > 0 else 0,
        # Seconds until the idle loop would stop the GPU (None when inactive).
        "auto_shutdown_in": max(0, IDLE_TIMEOUT - round(time.time() - last_request_time)) if last_request_time > 0 and active_zone else None,
        "token_expires_in": round(_token_expiry - time.time()) if _token_expiry > 0 else None,
    }
|
||||
|
||||
|
||||
@app.get("/stats")
async def get_stats():
    """Usage/cost summary plus per-zone fallback status for the dashboard.

    NOTE(review): unlike check_budget(), this endpoint does not zero the
    figures when stats["month"] is a past month -- right after a month
    rollover the hours/cost shown belong to the previous month until the
    next GPU start resets the record. Confirm whether that is intended.
    """
    stats = load_stats()
    hours = stats.get("gpu_seconds", 0) / 3600
    gpu_cost = hours * GPU_COST_PER_HOUR
    total_cost = gpu_cost + FIXED_MONTHLY_COST
    return {
        "month": stats.get("month"),
        "gpu_hours": round(hours, 2),
        "gpu_minutes": round(hours * 60, 1),
        "estimated_cost_usd": round(total_cost, 2),
        "gpu_cost_usd": round(gpu_cost, 2),
        "fixed_cost_usd": FIXED_MONTHLY_COST,
        "monthly_limit_hours": MONTHLY_LIMIT_HOURS,
        "remaining_hours": round(MONTHLY_LIMIT_HOURS - hours, 2),
        "requests_count": stats.get("requests", 0),
        "active_zone": stats.get("active_zone", "none"),
        "cost_per_hour": GPU_COST_PER_HOUR,
        "recent_requests": request_history[:10],
        # Static zone config merged with the latest attempt status.
        "zone_fallbacks": [
            {
                "label": config["label"],
                "zone": config["zone"],
                "machine": config["machine_type"],
                "gpu": config["accelerator"],
                **zone_status.get(config["label"], {"status": "unknown", "last_tried": None, "last_error": None}),
            }
            for config in ZONE_FALLBACKS
        ],
    }
|
||||
|
||||
|
||||
@app.post("/gpu/start")
async def gpu_start():
    """Manually boot the GPU and wait until WhisperX is healthy."""
    try:
        ip = await ensure_gpu_ready()
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=503)
    label = active_zone["label"] if active_zone else "unknown"
    return {"status": "running", "ip": ip, "zone": label}
|
||||
|
||||
|
||||
@app.post("/gpu/stop")
async def gpu_stop():
    """Manually stop the active GPU instance and record its runtime."""
    if not active_zone:
        return {"status": "no active instance"}
    try:
        await stop_instance_in_zone(active_zone["zone"], active_zone["instance"])
        # Close the billing session now that the VM is stopping.
        track_gpu_time()
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    return {"status": "stopped", "zone": active_zone["label"]}
|
||||
|
||||
|
||||
DASHBOARD_HTML = Path(__file__).parent / "dashboard.html"


@app.get("/", response_class=HTMLResponse)
async def dashboard():
    """Serve the bundled dashboard page, or a 404 hint when it is missing."""
    page = DASHBOARD_HTML
    if not page.exists():
        return HTMLResponse("<h1>Dashboard not found</h1><p>Place dashboard.html next to proxy.py</p>", status_code=404)
    return HTMLResponse(page.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
@app.api_route("/v1/{path:path}", methods=["POST", "GET"])
async def llm_proxy(request: Request, path: str):
    """Proxy OpenAI-compatible /v1/* requests to Ollama on the GPU VM.

    Boots the GPU if needed, forwards the raw body/headers, relays the
    upstream response verbatim, and records the request in the history.
    """
    global last_request_time, active_requests

    body = await request.body()
    # Drop headers the upstream stack must compute itself.
    headers = {
        k: v for k, v in request.headers.items()
        if k.lower() not in ("host", "transfer-encoding")
    }

    last_request_time = time.time()
    active_requests += 1
    start_time = time.time()
    result_status = 200
    try:
        ip = await ensure_ollama_ready()
        target = f"http://{ip}:{OLLAMA_PORT}/v1/{path}"
        log.info(f"Forwarding LLM request to {target}")
        # LLM completions are bounded at 5 minutes.
        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
            resp = await client.request(request.method, target, content=body, headers=headers)
        last_request_time = time.time()
        result_status = resp.status_code
        # Relay the upstream response verbatim (body + content type).
        return Response(
            content=resp.content,
            status_code=resp.status_code,
            media_type=resp.headers.get("content-type"),
        )
    except httpx.ReadTimeout:
        result_status = 504
        return JSONResponse({"error": "LLM timeout (5min)"}, status_code=504)
    except Exception as e:
        result_status = 502
        log.error(f"LLM proxy error: {e}")
        return JSONResponse({"error": str(e)}, status_code=502)
    finally:
        active_requests -= 1
        last_request_time = time.time()
        # Record the request (newest first, capped at MAX_HISTORY).
        request_history.insert(0, {
            "time": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "type": "LLM",
            "duration_sec": round(time.time() - start_time, 1),
            "status": result_status,
            "zone": active_zone["label"] if active_zone else "none",
        })
        if len(request_history) > MAX_HISTORY:
            request_history.pop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standalone entry point; production runs under systemd (see setup.sh).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=9090)
|
||||
5
deployment/asr-proxy/requirements.txt
Normal file
5
deployment/asr-proxy/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
fastapi==0.115.0
|
||||
uvicorn==0.30.0
|
||||
httpx==0.27.0
|
||||
PyJWT==2.9.0
|
||||
cryptography>=43.0.0
|
||||
87
deployment/asr-proxy/setup.sh
Normal file
87
deployment/asr-proxy/setup.sh
Normal file
@@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env bash
# DictIA ASR Proxy — Setup script
# Installs the GCP GPU proxy for cloud deployments.
set -euo pipefail

# Resolve install location and service user (overridable via env vars).
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
INSTALL_DIR="${ASR_PROXY_DIR:-$SCRIPT_DIR}"
SERVICE_USER="${ASR_PROXY_USER:-$(whoami)}"

echo "=== DictIA ASR Proxy Setup ==="
echo "Install directory: $INSTALL_DIR"
echo "Service user: $SERVICE_USER"
echo

# 1. Create virtual environment
if [ ! -d "$INSTALL_DIR/venv" ]; then
    echo "[1/4] Creating Python virtual environment..."
    python3 -m venv "$INSTALL_DIR/venv"
else
    echo "[1/4] Virtual environment already exists."
fi

# 2. Install dependencies
echo "[2/4] Installing Python dependencies..."
"$INSTALL_DIR/venv/bin/pip" install --quiet --upgrade pip
"$INSTALL_DIR/venv/bin/pip" install --quiet -r "$INSTALL_DIR/requirements.txt"

# 3. GCP credentials
# Interactive copy of the credentials JSON; skippable, but the proxy will
# not start without it.
if [ ! -f "$INSTALL_DIR/gcp-credentials.json" ]; then
    echo "[3/4] GCP credentials not found."
    echo "      Place your GCP service account or OAuth credentials at:"
    echo "      $INSTALL_DIR/gcp-credentials.json"
    echo
    echo "      For service account: download JSON from GCP Console > IAM > Service Accounts"
    echo "      For user credentials: run 'gcloud auth application-default login' and copy the file"
    echo
    read -rp "      Path to credentials file (or press Enter to skip): " CREDS_PATH
    if [ -n "$CREDS_PATH" ] && [ -f "$CREDS_PATH" ]; then
        cp "$CREDS_PATH" "$INSTALL_DIR/gcp-credentials.json"
        # Credentials are secrets: owner-only permissions.
        chmod 600 "$INSTALL_DIR/gcp-credentials.json"
        echo "      Credentials copied."
    else
        echo "      Skipped. You must add credentials before starting the proxy."
    fi
else
    echo "[3/4] GCP credentials found."
fi

# 4. Install systemd service
# The unit is generated here with the resolved user/dir (the repo's
# asr-proxy.service file is only a template with placeholders).
echo "[4/4] Installing systemd service..."
SERVICE_FILE="/etc/systemd/system/asr-proxy.service"

# NOTE(review): /tmp/asr-proxy.service is a predictable world-writable-dir
# path; consider mktemp on multi-user hosts.
cat > /tmp/asr-proxy.service <<UNIT
[Unit]
Description=DictIA ASR Proxy - GPU Auto-Start/Stop for WhisperX
After=network.target

[Service]
Type=simple
User=$SERVICE_USER
Restart=always
RestartSec=10
WorkingDirectory=$INSTALL_DIR
ExecStart=$INSTALL_DIR/venv/bin/python proxy.py
Environment=GOOGLE_APPLICATION_CREDENTIALS=$INSTALL_DIR/gcp-credentials.json
Environment=STATS_FILE=$INSTALL_DIR/usage-stats.json

[Install]
WantedBy=multi-user.target
UNIT

if [ "$(id -u)" -eq 0 ]; then
    cp /tmp/asr-proxy.service "$SERVICE_FILE"
    systemctl daemon-reload
    systemctl enable asr-proxy.service
    echo "  Service installed and enabled."
    echo "  Start with: systemctl start asr-proxy"
else
    echo "  Run as root to install systemd service, or copy manually:"
    echo "    sudo cp /tmp/asr-proxy.service $SERVICE_FILE"
    echo "    sudo systemctl daemon-reload && sudo systemctl enable asr-proxy"
fi

echo
echo "=== Setup complete ==="
echo "Dashboard: http://localhost:9090"
echo "Health: http://localhost:9090/health"
|
||||
Reference in New Issue
Block a user