"""DictIA ASR Proxy - Auto-start/stop GCP GPU for WhisperX + Ollama.
|
|
|
|
Uses Google Cloud Compute REST API directly (no gcloud CLI needed).
|
|
Proxies both ASR (WhisperX) and LLM (Ollama) requests.
|
|
Multi-zone fallback across Canada (Montreal + Toronto).
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
|
|
import httpx
|
|
import jwt as pyjwt
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import HTMLResponse, JSONResponse, Response
|
|
|
|
# Timestamped INFO logging so journald/docker output is readable.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger("asr-proxy")

# Config — paths relative to this script's directory by default
SCRIPT_DIR = Path(__file__).parent
GCP_PROJECT = os.getenv("GCP_PROJECT", "speakr-gpu")
WHISPERX_PORT = int(os.getenv("WHISPERX_PORT", "9000"))  # WhisperX HTTP port on the GPU VM
OLLAMA_PORT = int(os.getenv("OLLAMA_PORT", "11434"))  # Ollama HTTP port on the GPU VM
IDLE_TIMEOUT = int(os.getenv("IDLE_TIMEOUT", "300"))  # seconds idle before auto-stop
CREDS_FILE = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", str(SCRIPT_DIR / "gcp-credentials.json"))
STATS_FILE = os.getenv("STATS_FILE", str(SCRIPT_DIR / "usage-stats.json"))
MONTHLY_LIMIT_HOURS = float(os.getenv("MONTHLY_LIMIT_HOURS", "30"))  # hard monthly GPU budget
# Real GCP cost per GPU-hour (g2-standard-4 + L4): GPU ($0.837) + vCPU ($0.151) + RAM ($0.069)
GPU_COST_PER_HOUR = float(os.getenv("GPU_COST_PER_HOUR", "1.06"))
# Fixed monthly costs: SSD disks ($5.66) + snapshots ($4.19) ≈ $9.85/month
FIXED_MONTHLY_COST = float(os.getenv("FIXED_MONTHLY_COST", "9.85"))
SNAPSHOT_NAME = "whisperx-gpu-snapshot"  # boot-disk snapshot new instances are created from
HEALTH_POLL_INTERVAL = 5  # seconds between service health-check polls
BOOT_TIMEOUT = 300  # max seconds to wait for WhisperX/Ollama to answer after VM start

# Zone fallback order — Canada only, Montreal first
ZONE_FALLBACKS = [
    {
        "zone": "northamerica-northeast1-b",
        "instance": "whisperx-gpu-mtl1",
        "machine_type": "g2-standard-4",
        "accelerator": "nvidia-l4",
        "accel_count": 1,
        "label": "Montreal-b (L4)",
    },
    {
        "zone": "northamerica-northeast1-c",
        "instance": "whisperx-gpu-mtl2",
        "machine_type": "n1-standard-4",
        "accelerator": "nvidia-tesla-t4",
        "accel_count": 1,
        "label": "Montreal-c (T4)",
    },
    {
        "zone": "northamerica-northeast2-a",
        "instance": "whisperx-gpu-tor1",
        "machine_type": "g2-standard-4",
        "accelerator": "nvidia-l4",
        "accel_count": 1,
        "label": "Toronto-a (L4)",
    },
    {
        "zone": "northamerica-northeast2-b",
        "instance": "whisperx-gpu",
        "machine_type": "g2-standard-4",
        "accelerator": "nvidia-l4",
        "accel_count": 1,
        "label": "Toronto-b (L4)",
    },
]

# Boot-time script GCP runs on the VM: bring Docker, WhisperX and Ollama back up.
STARTUP_SCRIPT = """#!/bin/bash
systemctl start docker
sleep 5
docker start whisperx-asr 2>/dev/null || true
systemctl start ollama 2>/dev/null || true
"""

app = FastAPI(title="DictIA ASR Proxy")

# State (module-level; the app runs as a single asyncio process)
last_request_time = 0.0  # epoch of the most recent proxied request (0 = none yet)
active_requests = 0  # in-flight proxied requests; non-zero blocks idle shutdown
gpu_ip: str | None = None  # external IP of the running GPU VM, if any
active_zone: dict | None = None  # ZONE_FALLBACKS entry currently serving traffic
shutdown_task: asyncio.Task | None = None  # handle of the idle-shutdown background task

# Request history tracking (in-memory, last 20 requests)
request_history: list[dict] = []
MAX_HISTORY = 20

# Zone status tracking (keyed by zone label, surfaced via /stats)
zone_status: dict[str, dict] = {}

# Startup lock and failure cooldown
_startup_lock: asyncio.Lock | None = None  # created on app startup (needs a running event loop)
_last_failure_time: float = 0  # epoch of the last all-zones failure (0 = none)
FAILURE_COOLDOWN = 180  # seconds to back off after every zone failed

# OAuth2 token cache
_access_token: str | None = None
_token_expiry: float = 0  # epoch when the cached token expires
|
|
|
# --- Usage Stats ---
|
|
|
|
def load_stats() -> dict:
|
|
try:
|
|
with open(STATS_FILE) as f:
|
|
return json.load(f)
|
|
except (FileNotFoundError, json.JSONDecodeError):
|
|
return {"gpu_seconds": 0, "month": time.strftime("%Y-%m"), "requests": 0, "last_start": 0}
|
|
|
|
|
|
def save_stats(stats: dict):
|
|
with open(STATS_FILE, "w") as f:
|
|
json.dump(stats, f, indent=2)
|
|
|
|
|
|
def track_gpu_time():
    """Fold the elapsed GPU session into the monthly stats and clear last_start."""
    stats = load_stats()
    this_month = time.strftime("%Y-%m")
    if stats.get("month") != this_month:
        # Month rolled over — start a fresh record.
        stats = {"gpu_seconds": 0, "month": this_month, "requests": 0, "last_start": 0}
    started_at = stats.get("last_start", 0)
    if started_at > 0:
        stats["gpu_seconds"] += time.time() - started_at
        stats["last_start"] = 0
    save_stats(stats)
|
|
|
|
|
|
def check_budget() -> tuple[bool, float]:
    """Return (within_budget, gpu_hours_used_this_month)."""
    stats = load_stats()
    if stats.get("month") != time.strftime("%Y-%m"):
        # Stats belong to a previous month — nothing consumed yet this month.
        return True, 0.0
    used_hours = stats.get("gpu_seconds", 0) / 3600
    return used_hours < MONTHLY_LIMIT_HOURS, used_hours
|
|
|
|
|
|
# --- GCP Auth ---
|
|
|
|
async def get_access_token() -> str:
    """Return a valid OAuth2 access token for the GCP Compute API.

    Caches the token in module globals and refreshes it when fewer than 60
    seconds remain before expiry. Supports both credential types found in
    CREDS_FILE: service accounts (self-signed JWT bearer grant) and
    authorized users (refresh-token grant), selected by the "type" field.
    """
    global _access_token, _token_expiry
    # Reuse the cached token while it has more than 60s of life left.
    if _access_token and time.time() < _token_expiry - 60:
        return _access_token
    with open(CREDS_FILE) as f:
        creds = json.load(f)
    cred_type = creds.get("type", "authorized_user")
    async with httpx.AsyncClient() as client:
        if cred_type == "service_account":
            # Service account: sign a short-lived JWT with the private key
            # and exchange it for an access token.
            now = int(time.time())
            payload = {
                "iss": creds["client_email"],
                "scope": "https://www.googleapis.com/auth/compute",
                "aud": "https://oauth2.googleapis.com/token",
                "iat": now,
                "exp": now + 3600,
            }
            signed = pyjwt.encode(payload, creds["private_key"], algorithm="RS256")
            resp = await client.post(
                "https://oauth2.googleapis.com/token",
                data={
                    "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                    "assertion": signed,
                },
            )
        else:
            # Authorized user: exchange the long-lived refresh token.
            resp = await client.post(
                "https://oauth2.googleapis.com/token",
                data={
                    "client_id": creds["client_id"],
                    "client_secret": creds["client_secret"],
                    "refresh_token": creds["refresh_token"],
                    "grant_type": "refresh_token",
                },
            )
    resp.raise_for_status()
    data = resp.json()
    _access_token = data["access_token"]
    _token_expiry = time.time() + data.get("expires_in", 3600)
    log.info(f"Refreshed GCP access token ({cred_type})")
    return _access_token
|
|
|
|
|
|
# --- GCP Compute API ---

# Base URL for all Compute Engine v1 REST calls.
COMPUTE_BASE = "https://compute.googleapis.com/compute/v1"
|
|
|
|
|
|
async def gcp_api(method: str, url: str, **kwargs) -> httpx.Response:
    """Send an authenticated request to the GCP Compute API and return the raw response."""
    auth_header = {"Authorization": f"Bearer {await get_access_token()}"}
    async with httpx.AsyncClient(timeout=60) as client:
        return await client.request(method, url, headers=auth_header, **kwargs)
|
|
|
|
|
|
async def get_instance_info(zone: str, instance: str) -> dict | None:
    """Fetch instance metadata from GCP; None when missing or on API error."""
    resp = await gcp_api(
        "GET",
        f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}",
    )
    if resp.status_code == 404:
        return None
    if resp.status_code >= 400:
        log.error(f"GCP API error {resp.status_code}: {resp.text}")
        return None
    return resp.json()
|
|
|
|
|
|
def extract_ip(instance_data: dict) -> str:
    """Return the external NAT IP of the first interface's first access config, or ""."""
    for iface in instance_data.get("networkInterfaces", [])[:1]:
        for cfg in iface.get("accessConfigs", [])[:1]:
            return cfg.get("natIP", "")
    return ""
|
|
|
|
|
|
async def start_instance_in_zone(zone: str, instance: str) -> bool:
    """Issue a start call for the instance; True when GCP accepted the request."""
    endpoint = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}/start"
    resp = await gcp_api("POST", endpoint)
    if resp.status_code >= 400:
        log.warning(f"Failed to start {instance} in {zone}: {resp.status_code} {resp.text}")
        return False
    log.info(f"Start requested: {instance} in {zone}")
    return True
|
|
|
|
|
|
async def stop_instance_in_zone(zone: str, instance: str):
    """Issue a stop call for the instance (result is only logged, not raised)."""
    endpoint = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}/stop"
    resp = await gcp_api("POST", endpoint)
    if resp.status_code >= 400:
        log.error(f"Failed to stop {instance} in {zone}: {resp.status_code}")
    else:
        log.info(f"Stop requested: {instance} in {zone}")
|
|
|
|
|
|
async def create_instance_from_snapshot(config: dict) -> bool:
    """Create a GPU instance in the zone described by *config*.

    The boot disk is restored from the shared SNAPSHOT_NAME snapshot, so the
    VM comes up pre-provisioned; STARTUP_SCRIPT then restarts the services.
    Returns True when GCP accepted the insert, False on capacity/quota or
    any other error (classified in the logs).
    """
    zone = config["zone"]
    instance = config["instance"]
    machine = config["machine_type"]
    accel = config["accelerator"]
    accel_count = config["accel_count"]

    log.info(f"Creating {instance} in {zone} from snapshot...")

    body = {
        "name": instance,
        "machineType": f"zones/{zone}/machineTypes/{machine}",
        "disks": [{
            "boot": True,
            # Disk is deleted with the instance, keeping storage costs bounded.
            "autoDelete": True,
            "initializeParams": {
                "diskSizeGb": "50",
                "diskType": f"zones/{zone}/diskTypes/pd-ssd",
                "sourceSnapshot": f"global/snapshots/{SNAPSHOT_NAME}",
            },
        }],
        "networkInterfaces": [{
            "network": "global/networks/default",
            # Ephemeral external IP so the proxy can reach WhisperX/Ollama.
            "accessConfigs": [{"type": "ONE_TO_ONE_NAT", "name": "External NAT"}],
        }],
        "guestAccelerators": [{
            "acceleratorType": f"zones/{zone}/acceleratorTypes/{accel}",
            "acceleratorCount": accel_count,
        }],
        # GPU instances must terminate (not live-migrate) on host maintenance.
        "scheduling": {
            "onHostMaintenance": "TERMINATE",
            "automaticRestart": False,
        },
        "tags": {"items": ["whisperx-gpu"]},
        "metadata": {
            "items": [{"key": "startup-script", "value": STARTUP_SCRIPT}],
        },
    }

    url = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances"
    resp = await gcp_api("POST", url, json=body)

    if resp.status_code < 400:
        log.info(f"Created {instance} in {zone}")
        return True

    error_text = resp.text
    # Classify the failure so the zone-fallback loop logs something useful.
    if "ZONE_RESOURCE_POOL_EXHAUSTED" in error_text:
        log.warning(f"No capacity in {zone} -- skipping")
    elif "QUOTA" in error_text.upper():
        log.warning(f"Quota exceeded for {zone}: {error_text[:200]}")
    else:
        log.error(f"Failed to create in {zone}: {resp.status_code} {error_text[:200]}")
    return False
|
|
|
|
|
|
# --- Core Logic ---
|
|
|
|
async def wait_for_running(zone: str, instance: str, timeout: int = 120, grace: int = 15) -> bool:
    """Poll the instance (every 5s) until it reports RUNNING.

    Args:
        zone: GCP zone of the instance.
        instance: Instance name.
        timeout: Overall polling budget in seconds.
        grace: Seconds during which TERMINATED/STOPPED is still tolerated —
            presumably a just-started instance can briefly report its old
            state (TODO confirm against the Compute API lifecycle).

    Returns:
        True once RUNNING is observed; False on timeout, or early when the
        instance disappears, is STOPPING, or stays stopped past *grace* —
        all treated as "no capacity" by callers.
    """
    gone_count = 0
    start_time = time.time()
    for _ in range(timeout // 5):
        info = await get_instance_info(zone, instance)
        if info and info.get("status") == "RUNNING":
            return True
        status = info.get("status", "UNKNOWN") if info else "GONE"
        elapsed = time.time() - start_time
        if status == "GONE":
            # Require two consecutive misses before giving up, so a single
            # API blip doesn't abort the wait.
            gone_count += 1
            if gone_count >= 2:
                log.warning(f"{instance} in {zone}: instance disappeared (no capacity)")
                return False
        if status in ("STOPPING",):
            log.warning(f"{instance} in {zone}: status {status} (no capacity)")
            return False
        if status in ("TERMINATED", "STOPPED") and elapsed > grace:
            log.warning(f"{instance} in {zone}: status {status} after {elapsed:.0f}s (no capacity)")
            return False
        await asyncio.sleep(5)
    return False
|
|
|
|
|
|
async def delete_instance(zone: str, instance: str):
    """Delete the instance to release quota; a 404 (already gone) is silently OK."""
    endpoint = f"{COMPUTE_BASE}/projects/{GCP_PROJECT}/zones/{zone}/instances/{instance}"
    resp = await gcp_api("DELETE", endpoint)
    if resp.status_code < 400:
        log.info(f"Deleted {instance} in {zone} to free quota")
    elif resp.status_code != 404:
        log.warning(f"Failed to delete {instance} in {zone}: {resp.status_code}")
|
|
|
|
|
|
async def ensure_gpu_running() -> str:
    """Ensure a GPU instance is RUNNING somewhere and return its external IP.

    Strategy (serialized under the global startup lock):
      1. Sleep out any cooldown left over from a previous all-zones failure.
      2. Enforce the monthly GPU-hours budget.
      3. Fast path: reuse the currently active instance if still RUNNING.
      4. Otherwise walk ZONE_FALLBACKS in order — create from snapshot when
         the instance doesn't exist, start it when stopped, or wait when it
         is already booting — until one zone yields a RUNNING VM with an IP.

    Raises:
        RuntimeError: when the budget is exhausted or no zone has capacity.
    """
    global gpu_ip, active_zone, _last_failure_time

    # After a total failure, wait out the cooldown instead of hammering GCP
    # with create/start calls that are likely to fail again.
    if _last_failure_time > 0:
        remaining = FAILURE_COOLDOWN - (time.time() - _last_failure_time)
        if remaining > 0:
            log.info(f"GPU cooldown active ({int(remaining)}s remaining), waiting...")
            await asyncio.sleep(remaining)
        _last_failure_time = 0

    # One startup at a time: concurrent requests must not race the fallback.
    async with _startup_lock:
        ok, hours = check_budget()
        if not ok:
            raise RuntimeError(f"Monthly GPU limit reached ({hours:.1f}h / {MONTHLY_LIMIT_HOURS}h)")

        # Fast path: previously chosen instance is still up and has an IP.
        if active_zone:
            info = await get_instance_info(active_zone["zone"], active_zone["instance"])
            if info and info.get("status") == "RUNNING":
                gpu_ip = extract_ip(info)
                if gpu_ip:
                    return gpu_ip

        errors = []

        for config in ZONE_FALLBACKS:
            zone = config["zone"]
            instance = config["instance"]
            label = config["label"]

            log.info(f"Trying {label}...")
            info = await get_instance_info(zone, instance)

            if info is None:
                # Instance doesn't exist in this zone — create it from snapshot.
                created = await create_instance_from_snapshot(config)
                if not created:
                    zone_status[label] = {
                        "status": "no_capacity",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": "no capacity",
                    }
                    errors.append(f"{label}: no capacity")
                    continue
                if not await wait_for_running(zone, instance, grace=30):
                    zone_status[label] = {
                        "status": "error",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": "created but failed to start",
                    }
                    errors.append(f"{label}: created but failed to start")
                    # Remove the half-started VM so it doesn't pin quota.
                    await delete_instance(zone, instance)
                    await asyncio.sleep(3)
                    continue
            else:
                status = info.get("status", "UNKNOWN")

                if status == "RUNNING":
                    pass  # already up; fall through to the IP re-check below
                elif status in ("TERMINATED", "STOPPED"):
                    zone_status[label] = {
                        "status": "starting",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": None,
                    }
                    started = await start_instance_in_zone(zone, instance)
                    if not started:
                        zone_status[label] = {
                            "status": "error",
                            "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                            "last_error": "start rejected",
                        }
                        errors.append(f"{label}: start rejected")
                        continue
                    if not await wait_for_running(zone, instance, grace=20):
                        zone_status[label] = {
                            "status": "error",
                            "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                            "last_error": "didn't reach RUNNING",
                        }
                        errors.append(f"{label}: didn't reach RUNNING")
                        continue
                elif status in ("STAGING", "PROVISIONING"):
                    # Already booting (e.g. a prior attempt) — just wait for it.
                    zone_status[label] = {
                        "status": "starting",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": None,
                    }
                    if not await wait_for_running(zone, instance):
                        zone_status[label] = {
                            "status": "error",
                            "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                            "last_error": f"stuck in {status}",
                        }
                        errors.append(f"{label}: stuck in {status}")
                        continue
                elif status == "STOPPING":
                    # Can't start while stopping; delete so a later pass recreates it.
                    log.info(f"{label}: STOPPING, deleting to free quota")
                    await delete_instance(zone, instance)
                    await asyncio.sleep(3)
                    zone_status[label] = {
                        "status": "error",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": "was STOPPING, deleted",
                    }
                    errors.append(f"{label}: was STOPPING, deleted")
                    continue

            # Re-fetch to confirm RUNNING and pick up the external IP.
            info = await get_instance_info(zone, instance)
            if info and info.get("status") == "RUNNING":
                gpu_ip = extract_ip(info)
                if gpu_ip:
                    active_zone = config
                    _last_failure_time = 0
                    zone_status[label] = {
                        "status": "running",
                        "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                        "last_error": None,
                    }
                    # Book the session start for usage/budget accounting.
                    stats = load_stats()
                    stats["last_start"] = time.time()
                    stats["requests"] = stats.get("requests", 0) + 1
                    stats["active_zone"] = label
                    save_stats(stats)
                    log.info(f"GPU ready in {label}, IP: {gpu_ip}")
                    return gpu_ip

            zone_status[label] = {
                "status": "error",
                "last_tried": time.strftime("%Y-%m-%dT%H:%M:%S"),
                "last_error": "running but no IP",
            }
            errors.append(f"{label}: running but no IP")

        # Every zone failed — arm the cooldown and surface all the reasons.
        _last_failure_time = time.time()
        raise RuntimeError(
            f"No GPU available in any Canadian zone. Tried: {'; '.join(errors)}"
        )
|
|
|
|
|
|
async def ensure_gpu_ready() -> str:
    """Boot the GPU if needed and block until WhisperX answers its health check."""
    ip = await ensure_gpu_running()
    url = f"http://{ip}:{WHISPERX_PORT}/health"
    log.info(f"Waiting for WhisperX at {url}...")
    attempts = BOOT_TIMEOUT // HEALTH_POLL_INTERVAL
    async with httpx.AsyncClient(timeout=10) as client:
        for _ in range(attempts):
            try:
                if (await client.get(url)).status_code == 200:
                    log.info("WhisperX is healthy!")
                    return ip
            except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout):
                # Service not up yet — keep polling until the boot budget runs out.
                pass
            await asyncio.sleep(HEALTH_POLL_INTERVAL)
    raise RuntimeError("WhisperX did not become healthy in time")
|
|
|
|
|
|
async def ensure_ollama_ready() -> str:
    """Boot the GPU if needed and block until Ollama responds to /api/tags."""
    ip = await ensure_gpu_running()
    url = f"http://{ip}:{OLLAMA_PORT}/api/tags"
    log.info(f"Waiting for Ollama at {url}...")
    attempts = BOOT_TIMEOUT // HEALTH_POLL_INTERVAL
    async with httpx.AsyncClient(timeout=10) as client:
        for _ in range(attempts):
            try:
                if (await client.get(url)).status_code == 200:
                    log.info("Ollama is healthy!")
                    return ip
            except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout):
                # Service not up yet — keep polling until the boot budget runs out.
                pass
            await asyncio.sleep(HEALTH_POLL_INTERVAL)
    raise RuntimeError("Ollama did not become healthy in time")
|
|
|
|
|
|
async def idle_shutdown_loop():
    """Background task: stop the GPU after IDLE_TIMEOUT seconds of inactivity.

    Runs forever, checking once a minute. Skips while requests are in
    flight or no zone is active. When it stops the VM, the elapsed GPU
    time is folded into the stats file via track_gpu_time().
    """
    while True:
        await asyncio.sleep(60)
        if last_request_time == 0 or active_zone is None:
            continue  # nothing proxied yet / no instance to stop
        if active_requests > 0:
            continue  # a transcription or LLM call is still running
        elapsed = time.time() - last_request_time
        if elapsed >= IDLE_TIMEOUT:
            try:
                zone = active_zone["zone"]
                instance = active_zone["instance"]
                label = active_zone["label"]
                # Only stop if the VM is actually RUNNING (it may already be down).
                info = await get_instance_info(zone, instance)
                if info and info.get("status") == "RUNNING":
                    log.info(f"Idle {int(elapsed)}s -- stopping {label}")
                    await stop_instance_in_zone(zone, instance)
                    track_gpu_time()
            except Exception as e:
                # Keep the loop alive; a failed stop is retried next minute.
                log.error(f"Error stopping: {e}")
|
|
|
|
|
|
# --- Endpoints ---
|
|
|
|
# NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
# lifespan handlers — consider migrating when upgrading FastAPI.
@app.on_event("startup")
async def on_startup():
    """Create the startup lock, warm the token cache, and launch the idle loop."""
    global shutdown_task, _startup_lock
    # The lock must be created inside the running event loop.
    _startup_lock = asyncio.Lock()
    # Fetch a token now so bad credentials fail fast at boot.
    await get_access_token()
    shutdown_task = asyncio.create_task(idle_shutdown_loop())
    zones = ", ".join(c["label"] for c in ZONE_FALLBACKS)
    log.info(f"DictIA ASR Proxy started. Zones: [{zones}]. Idle: {IDLE_TIMEOUT}s, limit: {MONTHLY_LIMIT_HOURS}h")
|
|
|
|
|
|
@app.post("/asr")
async def asr_proxy(request: Request):
    """Proxy a transcription request to WhisperX on the GPU VM.

    Boots the GPU (with multi-zone fallback) if needed, forwards the raw
    request body, and records the call in the in-memory request history.
    Returns 504 on the 2h transcription timeout, 502 on any other failure.
    """
    global last_request_time, active_requests

    body = await request.body()
    # Drop headers that must not be forwarded verbatim to the upstream.
    headers = {
        k: v for k, v in request.headers.items()
        if k.lower() not in ("host", "transfer-encoding")
    }

    last_request_time = time.time()
    active_requests += 1
    start_time = time.time()
    result_status = 200
    try:
        ip = await ensure_gpu_ready()
        target = f"http://{ip}:{WHISPERX_PORT}/asr"
        log.info(f"Forwarding {len(body)} bytes to {target}")
        # Generous timeout: long recordings can take a long time to transcribe.
        async with httpx.AsyncClient(timeout=httpx.Timeout(7200.0)) as client:
            resp = await client.post(target, content=body, headers=headers)
        last_request_time = time.time()
        result_status = resp.status_code
        ct = resp.headers.get("content-type", "")
        if "application/json" in ct:
            return JSONResponse(content=resp.json(), status_code=resp.status_code)
        # BUGFIX: non-JSON upstream bodies were previously wrapped in
        # JSONResponse, which re-encoded plain text as a quoted JSON string
        # and mislabeled the content type. Pass the body through verbatim.
        return Response(
            content=resp.content,
            status_code=resp.status_code,
            media_type=ct or None,
        )
    except httpx.ReadTimeout:
        result_status = 504
        return JSONResponse({"error": "Transcription timeout (2h)"}, status_code=504)
    except Exception as e:
        result_status = 502
        log.error(f"Proxy error: {e}")
        return JSONResponse({"error": str(e)}, status_code=502)
    finally:
        active_requests -= 1
        last_request_time = time.time()
        # Record the call (newest first), bounded to MAX_HISTORY entries.
        request_history.insert(0, {
            "time": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "type": "ASR",
            "duration_sec": round(time.time() - start_time, 1),
            "status": result_status,
            "zone": active_zone["label"] if active_zone else "none",
        })
        if len(request_history) > MAX_HISTORY:
            request_history.pop()
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Status endpoint: proxy liveness, GPU instance state, and budget usage."""
    zone_label = active_zone["label"] if active_zone else "none"
    gpu_status = "unknown"
    if active_zone:
        try:
            info = await get_instance_info(active_zone["zone"], active_zone["instance"])
            gpu_status = info.get("status", "unknown") if info else "not_found"
        except Exception:
            # The health endpoint must never fail just because GCP is unreachable.
            pass
    ok, hours = check_budget()
    stats = load_stats()
    return {
        "proxy": "healthy",
        "gpu_instance": gpu_status,
        "gpu_zone": zone_label,
        "active_requests": active_requests,
        "idle_timeout": IDLE_TIMEOUT,
        "usage": {
            "month": stats.get("month"),
            "gpu_hours": round(hours, 2),
            "gpu_limit_hours": MONTHLY_LIMIT_HOURS,
            "requests_count": stats.get("requests", 0),
            "budget_ok": ok,
        },
        "gpu_ip": gpu_ip,
        "machine_type": active_zone.get("machine_type", "unknown") if active_zone else "unknown",
        "gpu_model": active_zone.get("accelerator", "unknown") if active_zone else "unknown",
        # Seconds since the last proxied request (0 when none seen yet).
        "idle_seconds": round(time.time() - last_request_time) if last_request_time > 0 else 0,
        # Countdown to auto-shutdown; None when no GPU is active.
        "auto_shutdown_in": max(0, IDLE_TIMEOUT - round(time.time() - last_request_time)) if last_request_time > 0 and active_zone else None,
        "token_expires_in": round(_token_expiry - time.time()) if _token_expiry > 0 else None,
    }
|
|
|
|
|
|
@app.get("/stats")
async def get_stats():
    """Dashboard data: GPU hours, cost estimate, recent requests, zone states."""
    stats = load_stats()
    hours = stats.get("gpu_seconds", 0) / 3600
    gpu_cost = hours * GPU_COST_PER_HOUR
    # Total = variable GPU time + fixed disk/snapshot cost for the month.
    total_cost = gpu_cost + FIXED_MONTHLY_COST
    return {
        "month": stats.get("month"),
        "gpu_hours": round(hours, 2),
        "gpu_minutes": round(hours * 60, 1),
        "estimated_cost_usd": round(total_cost, 2),
        "gpu_cost_usd": round(gpu_cost, 2),
        "fixed_cost_usd": FIXED_MONTHLY_COST,
        "monthly_limit_hours": MONTHLY_LIMIT_HOURS,
        "remaining_hours": round(MONTHLY_LIMIT_HOURS - hours, 2),
        "requests_count": stats.get("requests", 0),
        "active_zone": stats.get("active_zone", "none"),
        "cost_per_hour": GPU_COST_PER_HOUR,
        "recent_requests": request_history[:10],
        # Static zone config merged with the latest per-zone attempt status.
        "zone_fallbacks": [
            {
                "label": config["label"],
                "zone": config["zone"],
                "machine": config["machine_type"],
                "gpu": config["accelerator"],
                **zone_status.get(config["label"], {"status": "unknown", "last_tried": None, "last_error": None}),
            }
            for config in ZONE_FALLBACKS
        ],
    }
|
|
|
|
|
|
@app.post("/gpu/start")
async def gpu_start():
    """Manually boot the GPU and wait until WhisperX is serving."""
    try:
        ip = await ensure_gpu_ready()
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=503)
    label = active_zone["label"] if active_zone else "unknown"
    return {"status": "running", "ip": ip, "zone": label}
|
|
|
|
|
|
@app.post("/gpu/stop")
async def gpu_stop():
    """Manually stop the active GPU instance and book the elapsed GPU time."""
    if not active_zone:
        return {"status": "no active instance"}
    try:
        await stop_instance_in_zone(active_zone["zone"], active_zone["instance"])
        track_gpu_time()
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    return {"status": "stopped", "zone": active_zone["label"]}
|
|
|
|
|
|
# Static dashboard page shipped alongside this script.
DASHBOARD_HTML = Path(__file__).parent / "dashboard.html"


@app.get("/", response_class=HTMLResponse)
async def dashboard():
    """Serve the static dashboard page, or a 404 hint when it's missing."""
    page = DASHBOARD_HTML
    if not page.exists():
        return HTMLResponse("<h1>Dashboard not found</h1><p>Place dashboard.html next to proxy.py</p>", status_code=404)
    return HTMLResponse(page.read_text(encoding="utf-8"))
|
|
|
|
|
|
@app.api_route("/v1/{path:path}", methods=["POST", "GET"])
async def llm_proxy(request: Request, path: str):
    """Proxy OpenAI-compatible /v1/* requests to Ollama on the GPU VM.

    Boots the GPU if needed, forwards the method, body and headers, and
    returns the upstream response (body, status, content type) as-is.
    Returns 504 on the 5-minute LLM timeout, 502 on any other failure.
    """
    global last_request_time, active_requests

    body = await request.body()
    # Drop headers that must not be forwarded verbatim to the upstream.
    headers = {
        k: v for k, v in request.headers.items()
        if k.lower() not in ("host", "transfer-encoding")
    }

    last_request_time = time.time()
    active_requests += 1
    start_time = time.time()
    result_status = 200
    try:
        ip = await ensure_ollama_ready()
        target = f"http://{ip}:{OLLAMA_PORT}/v1/{path}"
        log.info(f"Forwarding LLM request to {target}")
        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
            resp = await client.request(request.method, target, content=body, headers=headers)
        last_request_time = time.time()
        result_status = resp.status_code
        # NOTE(review): the response is fully buffered here, so upstream
        # streaming (e.g. SSE chat completions) would not stream through —
        # confirm clients don't rely on streaming.
        return Response(
            content=resp.content,
            status_code=resp.status_code,
            media_type=resp.headers.get("content-type"),
        )
    except httpx.ReadTimeout:
        result_status = 504
        return JSONResponse({"error": "LLM timeout (5min)"}, status_code=504)
    except Exception as e:
        result_status = 502
        log.error(f"LLM proxy error: {e}")
        return JSONResponse({"error": str(e)}, status_code=502)
    finally:
        active_requests -= 1
        last_request_time = time.time()
        # Record the call (newest first), bounded to MAX_HISTORY entries.
        request_history.insert(0, {
            "time": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "type": "LLM",
            "duration_sec": round(time.time() - start_time, 1),
            "status": result_status,
            "zone": active_zone["label"] if active_zone else "none",
        })
        if len(request_history) > MAX_HISTORY:
            request_history.pop()
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Dev/standalone entry point: serve on all interfaces, port 9090.
    uvicorn.run(app, host="0.0.0.0", port=9090)
|