"""DictIA ASR Proxy - Auto-start/stop GCP GPU for WhisperX + Ollama. Uses Google Cloud Compute REST API directly (no gcloud CLI needed). Proxies both ASR (WhisperX) and LLM (Ollama) requests. Multi-zone fallback across Canada (Montreal + Toronto). """ import asyncio import json import logging import os import time import httpx import jwt as pyjwt from pathlib import Path from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse, JSONResponse, Response logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") log = logging.getLogger("asr-proxy") # Config — paths relative to this script's directory by default SCRIPT_DIR = Path(__file__).parent GCP_PROJECT = os.getenv("GCP_PROJECT", "speakr-gpu") WHISPERX_PORT = int(os.getenv("WHISPERX_PORT", "9000")) OLLAMA_PORT = int(os.getenv("OLLAMA_PORT", "11434")) IDLE_TIMEOUT = int(os.getenv("IDLE_TIMEOUT", "300")) CREDS_FILE = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", str(SCRIPT_DIR / "gcp-credentials.json")) STATS_FILE = os.getenv("STATS_FILE", str(SCRIPT_DIR / "usage-stats.json")) MONTHLY_LIMIT_HOURS = float(os.getenv("MONTHLY_LIMIT_HOURS", "30")) # Real GCP cost per GPU-hour (g2-standard-4 + L4): GPU ($0.837) + vCPU ($0.151) + RAM ($0.069) GPU_COST_PER_HOUR = float(os.getenv("GPU_COST_PER_HOUR", "1.06")) # Fixed monthly costs: SSD disks ($5.66) + snapshots ($4.19) ≈ $9.85/month FIXED_MONTHLY_COST = float(os.getenv("FIXED_MONTHLY_COST", "9.85")) SNAPSHOT_NAME = "whisperx-gpu-snapshot" HEALTH_POLL_INTERVAL = 5 BOOT_TIMEOUT = 300 # Zone fallback order — Canada only, Montreal first ZONE_FALLBACKS = [ { "zone": "northamerica-northeast1-b", "instance": "whisperx-gpu-mtl1", "machine_type": "g2-standard-4", "accelerator": "nvidia-l4", "accel_count": 1, "label": "Montreal-b (L4)", }, { "zone": "northamerica-northeast1-c", "instance": "whisperx-gpu-mtl2", "machine_type": "n1-standard-4", "accelerator": "nvidia-tesla-t4", "accel_count": 1, "label": "Montreal-c (T4)", }, { "zone": "northamerica-northeast2-a", "instance": "whisperx-gpu-tor1", "machine_type": "g2-standard-4", "accelerator": "nvidia-l4", "accel_count": 1, "label": "Toronto-a (L4)", }, { "zone": "northamerica-northeast2-b", "instance": "whisperx-gpu", "machine_type": "g2-standard-4", "accelerator": "nvidia-l4", "accel_count": 1, "label": "Toronto-b (L4)", }, ] STARTUP_SCRIPT = """#!/bin/bash systemctl start docker sleep 5 docker start whisperx-asr 2>/dev/null || true systemctl start ollama 2>/dev/null || true """ app = FastAPI(title="DictIA ASR Proxy") # State last_request_time = 0.0 active_requests = 0 gpu_ip: str | None = None active_zone: dict | None = None shutdown_task: asyncio.Task | None = None # Request history tracking (in-memory, last 20 requests) request_history: list[dict] = [] MAX_HISTORY = 20 # Zone status tracking zone_status: dict[str, dict] = {} # Startup lock and failure cooldown _startup_lock: asyncio.Lock | None = None _last_failure_time: float = 0 FAILURE_COOLDOWN = 180 # OAuth2 token cache _access_token: str | None = None _token_expiry: float = 0 # --- Usage Stats --- def load_stats() -> dict: try: with open(STATS_FILE) as f: return json.load(f) except (FileNotFoundError, json.JSONDecodeError): return {"gpu_seconds": 0, "month": time.strftime("%Y-%m"), "requests": 0, "last_start": 0} def save_stats(stats: dict): with open(STATS_FILE, "w") as f: json.dump(stats, f, indent=2) def track_gpu_time(): stats = load_stats() current_month = time.strftime("%Y-%m") if stats.get("month") != current_month: stats = {"gpu_seconds": 0, "month": current_month, "requests": 0, "last_start": 0} if stats.get("last_start", 0) > 0: elapsed = time.time() - stats["last_start"] stats["gpu_seconds"] += elapsed stats["last_start"] = 0 save_stats(stats) def check_budget() -> tuple[bool, float]: stats = load_stats() current_month = time.strftime("%Y-%m") if stats.get("month") != current_month: return True, 0.0 hours_used = stats.get("gpu_seconds", 0) / 3600 return hours_used < MONTHLY_LIMIT_HOURS, hours_used # --- GCP Auth --- async def get_access_token() -> str: global _access_token, _token_expiry if _access_token and time.time() < _token_expiry - 60: return _access_token with open(CREDS_FILE) as f: creds = json.load(f) cred_type = creds.get("type", "authorized_user") async with httpx.AsyncClient() as client: if cred_type == "service_account": now = int(time.time()) payload = { "iss": creds["client_email"], "scope": "https://www.googleapis.com/auth/compute", "aud": "https://oauth2.googleapis.com/token", "iat": now, "exp": now + 3600, } signed = pyjwt.encode(payload, creds["private_key"], algorithm="RS256") resp = await client.post( "https://oauth2.googleapis.com/token", data={ "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "assertion": signed, }, ) else: resp = await client.post( "https://oauth2.googleapis.com/token", data={ "client_id": creds["client_id"], "client_secret": creds["client_secret"], "refresh_token": creds["refresh_token"], "grant_type": "refresh_token", }, ) resp.raise_for_status() data = resp.json() _access_token = data["access_token"] _token_expiry = time.time() + data.get("expires_in", 3600) log.info(f"Refreshed GCP access token ({cred_type})") return _access_token # --- GCP Compute API --- COMPUTE_BASE = "https://compute.googleapis.com/compute/v1" async def gcp_api(method: str, url: str, **kwargs) -> httpx.Response: token = await get_access_token() async with httpx.AsyncClient(timeout=60) as client: resp = await client.request( method, url, headers={"Authorization": f"Bearer {token}"}, **kwargs, ) return resp async def get_instance_info(zone: str, instance: str) -> dict | None: url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}" resp = await gcp_api("GET", url) if resp.status_code == 404: return None if resp.status_code >= 400: log.error(f"GCP API error {resp.status_code}: {resp.text}") return None return resp.json() def extract_ip(instance_data: dict) -> str: interfaces = instance_data.get("networkInterfaces", []) if interfaces: access = interfaces[0].get("accessConfigs", []) if access: return access[0].get("natIP", "") return "" async def start_instance_in_zone(zone: str, instance: str) -> bool: url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}/start" resp = await gcp_api("POST", url) if resp.status_code < 400: log.info(f"Start requested: {instance} in {zone}") return True log.warning(f"Failed to start {instance} in {zone}: {resp.status_code} {resp.text}") return False async def stop_instance_in_zone(zone: str, instance: str): url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}/stop" resp = await gcp_api("POST", url) if resp.status_code < 400: log.info(f"Stop requested: {instance} in {zone}") else: log.error(f"Failed to stop {instance} in {zone}: {resp.status_code}") async def create_instance_from_snapshot(config: dict) -> bool: zone = config["zone"] instance = config["instance"] machine = config["machine_type"] accel = config["accelerator"] accel_count = config["accel_count"] log.info(f"Creating {instance} in {zone} from snapshot...") body = { "name": instance, "machineType": f"zones/{zone}/machineTypes/{machine}", "disks": [{ "boot": True, "autoDelete": True, "initializeParams": { "diskSizeGb": "50", "diskType": f"zones/{zone}/diskTypes/pd-ssd", "sourceSnapshot": f"global/snapshots/{SNAPSHOT_NAME}", }, }], "networkInterfaces": [{ "network": "global/networks/default", "accessConfigs": [{"type": "ONE_TO_ONE_NAT", "name": "External NAT"}], }], "guestAccelerators": [{ "acceleratorType": f"zones/{zone}/acceleratorTypes/{accel}", "acceleratorCount": accel_count, }], "scheduling": { "onHostMaintenance": "TERMINATE", "automaticRestart": False, }, "tags": {"items": ["whisperx-gpu"]}, "metadata": { "items": [{"key": "startup-script", "value": STARTUP_SCRIPT}], }, } url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances" resp = await gcp_api("POST", url, json=body) if resp.status_code < 400: log.info(f"Created {instance} in {zone}") return True error_text = resp.text if "ZONE_RESOURCE_POOL_EXHAUSTED" in error_text: log.warning(f"No capacity in {zone} -- skipping") elif "QUOTA" in error_text.upper(): log.warning(f"Quota exceeded for {zone}: {error_text[:200]}") else: log.error(f"Failed to create in {zone}: {resp.status_code} {error_text[:200]}") return False # --- Core Logic --- async def wait_for_running(zone: str, instance: str, timeout: int = 120, grace: int = 15) -> bool: gone_count = 0 start_time = time.time() for _ in range(timeout // 5): info = await get_instance_info(zone, instance) if info and info.get("status") == "RUNNING": return True status = info.get("status", "UNKNOWN") if info else "GONE" elapsed = time.time() - start_time if status == "GONE": gone_count += 1 if gone_count >= 2: log.warning(f"{instance} in {zone}: instance disappeared (no capacity)") return False if status in ("STOPPING",): log.warning(f"{instance} in {zone}: status {status} (no capacity)") return False if status in ("TERMINATED", "STOPPED") and elapsed > grace: log.warning(f"{instance} in {zone}: status {status} after {elapsed:.0f}s (no capacity)") return False await asyncio.sleep(5) return False async def delete_instance(zone: str, instance: str): url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}" resp = await gcp_api("DELETE", url) if resp.status_code < 400: log.info(f"Deleted {instance} in {zone} to free quota") elif resp.status_code == 404: pass else: log.warning(f"Failed to delete {instance} in {zone}: {resp.status_code}") async def ensure_gpu_running() -> str: global gpu_ip, active_zone, _last_failure_time if _last_failure_time > 0: remaining = FAILURE_COOLDOWN - (time.time() - _last_failure_time) if remaining > 0: log.info(f"GPU cooldown active ({int(remaining)}s remaining), waiting...") await asyncio.sleep(remaining) _last_failure_time = 0 async with _startup_lock: ok, hours = check_budget() if not ok: raise RuntimeError(f"Monthly GPU limit reached ({hours:.1f}h / {MONTHLY_LIMIT_HOURS}h)") if active_zone: info = await get_instance_info(active_zone["zone"], active_zone["instance"]) if info and info.get("status") == "RUNNING": gpu_ip = extract_ip(info) if gpu_ip: return gpu_ip errors = [] for config in ZONE_FALLBACKS: zone = config["zone"] instance = config["instance"] label = config["label"] log.info(f"Trying {label}...") info = await get_instance_info(zone, instance) if info is None: created = await create_instance_from_snapshot(config) if not created: zone_status[label] = { "status": "no_capacity", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": "no capacity", } errors.append(f"{label}: no capacity") continue if not await wait_for_running(zone, instance, grace=30): zone_status[label] = { "status": "error", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": "created but failed to start", } errors.append(f"{label}: created but failed to start") await delete_instance(zone, instance) await asyncio.sleep(3) continue else: status = info.get("status", "UNKNOWN") if status == "RUNNING": pass elif status in ("TERMINATED", "STOPPED"): zone_status[label] = { "status": "starting", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": None, } started = await start_instance_in_zone(zone, instance) if not started: zone_status[label] = { "status": "error", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": "start rejected", } errors.append(f"{label}: start rejected") continue if not await wait_for_running(zone, instance, grace=20): zone_status[label] = { "status": "error", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": "didn't reach RUNNING", } errors.append(f"{label}: didn't reach RUNNING") continue elif status in ("STAGING", "PROVISIONING"): zone_status[label] = { "status": "starting", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": None, } if not await wait_for_running(zone, instance): zone_status[label] = { "status": "error", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": f"stuck in {status}", } errors.append(f"{label}: stuck in {status}") continue elif status == "STOPPING": log.info(f"{label}: STOPPING, deleting to free quota") await delete_instance(zone, instance) await asyncio.sleep(3) zone_status[label] = { "status": "error", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": "was STOPPING, deleted", } errors.append(f"{label}: was STOPPING, deleted") continue info = await get_instance_info(zone, instance) if info and info.get("status") == "RUNNING": gpu_ip = extract_ip(info) if gpu_ip: active_zone = config _last_failure_time = 0 zone_status[label] = { "status": "running", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": None, } stats = load_stats() stats["last_start"] = time.time() stats["requests"] = stats.get("requests", 0) + 1 stats["active_zone"] = label save_stats(stats) log.info(f"GPU ready in {label}, IP: {gpu_ip}") return gpu_ip zone_status[label] = { "status": "error", "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"), "last_error": "running but no IP", } errors.append(f"{label}: running but no IP") _last_failure_time = time.time() raise RuntimeError( f"No GPU available in any Canadian zone. Tried: {'; '.join(errors)}" ) async def ensure_gpu_ready() -> str: ip = await ensure_gpu_running() url = f"http://{ip}:{WHISPERX_PORT}/health" log.info(f"Waiting for WhisperX at {url}...") async with httpx.AsyncClient(timeout=10) as client: for _ in range(BOOT_TIMEOUT // HEALTH_POLL_INTERVAL): try: resp = await client.get(url) if resp.status_code == 200: log.info("WhisperX is healthy!") return ip except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout): pass await asyncio.sleep(HEALTH_POLL_INTERVAL) raise RuntimeError("WhisperX did not become healthy in time") async def ensure_ollama_ready() -> str: ip = await ensure_gpu_running() url = f"http://{ip}:{OLLAMA_PORT}/api/tags" log.info(f"Waiting for Ollama at {url}...") async with httpx.AsyncClient(timeout=10) as client: for _ in range(BOOT_TIMEOUT // HEALTH_POLL_INTERVAL): try: resp = await client.get(url) if resp.status_code == 200: log.info("Ollama is healthy!") return ip except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout): pass await asyncio.sleep(HEALTH_POLL_INTERVAL) raise RuntimeError("Ollama did not become healthy in time") async def idle_shutdown_loop(): while True: await asyncio.sleep(60) if last_request_time == 0 or active_zone is None: continue if active_requests > 0: continue elapsed = time.time() - last_request_time if elapsed >= IDLE_TIMEOUT: try: zone = active_zone["zone"] instance = active_zone["instance"] label = active_zone["label"] info = await get_instance_info(zone, instance) if info and info.get("status") == "RUNNING": log.info(f"Idle {int(elapsed)}s -- stopping {label}") await stop_instance_in_zone(zone, instance) track_gpu_time() except Exception as e: log.error(f"Error stopping: {e}") # --- Endpoints --- @app.on_event("startup") async def on_startup(): global shutdown_task, _startup_lock _startup_lock = asyncio.Lock() await get_access_token() shutdown_task = asyncio.create_task(idle_shutdown_loop()) zones = ", ".join(c["label"] for c in ZONE_FALLBACKS) log.info(f"DictIA ASR Proxy started. Zones: [{zones}]. Idle: {IDLE_TIMEOUT}s, limit: {MONTHLY_LIMIT_HOURS}h") @app.post("/asr") async def asr_proxy(request: Request): global last_request_time, active_requests body = await request.body() headers = { k: v for k, v in request.headers.items() if k.lower() not in ("host", "transfer-encoding") } last_request_time = time.time() active_requests += 1 start_time = time.time() result_status = 200 try: ip = await ensure_gpu_ready() target = f"http://{ip}:{WHISPERX_PORT}/asr" log.info(f"Forwarding {len(body)} bytes to {target}") async with httpx.AsyncClient(timeout=httpx.Timeout(7200.0)) as client: resp = await client.post(target, content=body, headers=headers) last_request_time = time.time() result_status = resp.status_code ct = resp.headers.get("content-type", "") if "application/json" in ct: return JSONResponse(content=resp.json(), status_code=resp.status_code) else: return JSONResponse(content=resp.text, status_code=resp.status_code) except httpx.ReadTimeout: result_status = 504 return JSONResponse({"error": "Transcription timeout (2h)"}, status_code=504) except Exception as e: result_status = 502 log.error(f"Proxy error: {e}") return JSONResponse({"error": str(e)}, status_code=502) finally: active_requests -= 1 last_request_time = time.time() request_history.insert(0, { "time": time.strftime("%Y-%m-%dT%H:%M:%S"), "type": "ASR", "duration_sec": round(time.time() - start_time, 1), "status": result_status, "zone": active_zone["label"] if active_zone else "none", }) if len(request_history) > MAX_HISTORY: request_history.pop() @app.get("/health") async def health(): zone_label = active_zone["label"] if active_zone else "none" gpu_status = "unknown" if active_zone: try: info = await get_instance_info(active_zone["zone"], active_zone["instance"]) gpu_status = info.get("status", "unknown") if info else "not_found" except Exception: pass ok, hours = check_budget() stats = load_stats() return { "proxy": "healthy", "gpu_instance": gpu_status, "gpu_zone": zone_label, "active_requests": active_requests, "idle_timeout": IDLE_TIMEOUT, "usage": { "month": stats.get("month"), "gpu_hours": round(hours, 2), "gpu_limit_hours": MONTHLY_LIMIT_HOURS, "requests_count": stats.get("requests", 0), "budget_ok": ok, }, "gpu_ip": gpu_ip, "machine_type": active_zone.get("machine_type", "unknown") if active_zone else "unknown", "gpu_model": active_zone.get("accelerator", "unknown") if active_zone else "unknown", "idle_seconds": round(time.time() - last_request_time) if last_request_time > 0 else 0, "auto_shutdown_in": max(0, IDLE_TIMEOUT - round(time.time() - last_request_time)) if last_request_time > 0 and active_zone else None, "token_expires_in": round(_token_expiry - time.time()) if _token_expiry > 0 else None, } @app.get("/stats") async def get_stats(): stats = load_stats() hours = stats.get("gpu_seconds", 0) / 3600 gpu_cost = hours * GPU_COST_PER_HOUR total_cost = gpu_cost + FIXED_MONTHLY_COST return { "month": stats.get("month"), "gpu_hours": round(hours, 2), "gpu_minutes": round(hours * 60, 1), "estimated_cost_usd": round(total_cost, 2), "gpu_cost_usd": round(gpu_cost, 2), "fixed_cost_usd": FIXED_MONTHLY_COST, "monthly_limit_hours": MONTHLY_LIMIT_HOURS, "remaining_hours": round(MONTHLY_LIMIT_HOURS - hours, 2), "requests_count": stats.get("requests", 0), "active_zone": stats.get("active_zone", "none"), "cost_per_hour": GPU_COST_PER_HOUR, "recent_requests": request_history[:10], "zone_fallbacks": [ { "label": config["label"], "zone": config["zone"], "machine": config["machine_type"], "gpu": config["accelerator"], **zone_status.get(config["label"], {"status": "unknown", "last_tried": None, "last_error": None}), } for config in ZONE_FALLBACKS ], } @app.post("/gpu/start") async def gpu_start(): try: ip = await ensure_gpu_ready() label = active_zone["label"] if active_zone else "unknown" return {"status": "running", "ip": ip, "zone": label} except Exception as e: return JSONResponse({"error": str(e)}, status_code=503) @app.post("/gpu/stop") async def gpu_stop(): if not active_zone: return {"status": "no active instance"} try: await stop_instance_in_zone(active_zone["zone"], active_zone["instance"]) track_gpu_time() return {"status": "stopped", "zone": active_zone["label"]} except Exception as e: return JSONResponse({"error": str(e)}, status_code=500) DASHBOARD_HTML = Path(__file__).parent / "dashboard.html" @app.get("/", response_class=HTMLResponse) async def dashboard(): if DASHBOARD_HTML.exists(): return HTMLResponse(DASHBOARD_HTML.read_text(encoding="utf-8")) return HTMLResponse("
Place dashboard.html next to proxy.py
", status_code=404) @app.api_route("/v1/{path:path}", methods=["POST", "GET"]) async def llm_proxy(request: Request, path: str): global last_request_time, active_requests body = await request.body() headers = { k: v for k, v in request.headers.items() if k.lower() not in ("host", "transfer-encoding") } last_request_time = time.time() active_requests += 1 start_time = time.time() result_status = 200 try: ip = await ensure_ollama_ready() target = f"http://{ip}:{OLLAMA_PORT}/v1/{path}" log.info(f"Forwarding LLM request to {target}") async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client: resp = await client.request(request.method, target, content=body, headers=headers) last_request_time = time.time() result_status = resp.status_code return Response( content=resp.content, status_code=resp.status_code, media_type=resp.headers.get("content-type"), ) except httpx.ReadTimeout: result_status = 504 return JSONResponse({"error": "LLM timeout (5min)"}, status_code=504) except Exception as e: result_status = 502 log.error(f"LLM proxy error: {e}") return JSONResponse({"error": str(e)}, status_code=502) finally: active_requests -= 1 last_request_time = time.time() request_history.insert(0, { "time": time.strftime("%Y-%m-%dT%H:%M:%S"), "type": "LLM", "duration_sec": round(time.time() - start_time, 1), "status": result_status, "zone": active_zone["label"] if active_zone else "none", }) if len(request_history) > MAX_HISTORY: request_history.pop() if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=9090)