How to monitor the status of model training running on the server via FastAPI

Introduction

To monitor model training via FastAPI, run the training in a background task (using BackgroundTasks, asyncio, or a task queue like Celery) and expose status endpoints that clients can poll or stream via Server-Sent Events (SSE). The training function updates a shared state object (dictionary, database row, or Redis key) with metrics like current epoch, loss, accuracy, and progress percentage. The client queries GET /training/{job_id}/status to retrieve the latest state. This pattern decouples the long-running training from the HTTP request-response cycle.

Basic Setup with Background Tasks

```python
import time
from datetime import datetime
from uuid import uuid4

from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# In-memory job store (use Redis or a database in production)
training_jobs = {}

class TrainingStatus(BaseModel):
    job_id: str
    status: str  # "queued", "running", "completed", "failed"
    epoch: int = 0
    total_epochs: int = 0
    loss: float = 0.0
    accuracy: float = 0.0
    started_at: datetime | None = None
    completed_at: datetime | None = None
    error: str | None = None

def train_model(job_id: str, config: dict):
    """Simulate model training with status updates.

    FastAPI runs a plain `def` background task in a thread pool,
    so time.sleep() here does not block the event loop.
    """
    job = training_jobs[job_id]
    job.status = "running"
    job.started_at = datetime.now()
    job.total_epochs = config.get("epochs", 10)

    try:
        for epoch in range(1, job.total_epochs + 1):
            time.sleep(2)  # Simulate epoch training time
            job.epoch = epoch
            job.loss = 1.0 / epoch
            job.accuracy = 1 - (1.0 / (epoch + 1))
        job.status = "completed"
    except Exception as exc:
        # Record the failure instead of leaving the job stuck in "running"
        job.status = "failed"
        job.error = str(exc)
    finally:
        job.completed_at = datetime.now()

@app.post("/training/start")
async def start_training(background_tasks: BackgroundTasks):
    job_id = str(uuid4())
    training_jobs[job_id] = TrainingStatus(
        job_id=job_id, status="queued", total_epochs=10
    )
    background_tasks.add_task(train_model, job_id, {"epochs": 10})
    return {"job_id": job_id, "message": "Training started"}

@app.get("/training/{job_id}/status")
async def get_status(job_id: str):
    if job_id not in training_jobs:
        # Returning ({...}, 404) would be sent as a JSON array with status 200
        raise HTTPException(status_code=404, detail="Job not found")
    return training_jobs[job_id]
```
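In the example above, the training function writes epoch, loss, and accuracy in separate statements from a worker thread while get_status reads from the event loop, so a request can land between two writes and see a new epoch paired with the previous loss. A minimal sketch of a lock-protected store, assuming a hypothetical JobStore class that would replace the bare dict, applies each epoch's fields in one step and hands out copies:

```python
import threading
from copy import deepcopy
from typing import Optional


class JobStore:
    """Hypothetical thread-safe job store (state is still local to one process)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._jobs: dict = {}

    def create(self, job_id: str, **fields) -> None:
        with self._lock:
            self._jobs[job_id] = {"job_id": job_id, "status": "queued", **fields}

    def update(self, job_id: str, **fields) -> None:
        # All fields for one epoch are written under the lock, so a reader never
        # sees a new epoch number paired with the previous epoch's loss
        with self._lock:
            self._jobs[job_id].update(fields)

    def snapshot(self, job_id: str) -> Optional[dict]:
        # Return a copy so the caller can serialize it while training keeps writing
        with self._lock:
            job = self._jobs.get(job_id)
            return deepcopy(job) if job is not None else None


def train(store: JobStore, job_id: str, epochs: int) -> None:
    """Hypothetical training loop that reports progress through the store."""
    store.update(job_id, status="running", total_epochs=epochs)
    for epoch in range(1, epochs + 1):
        # Real training work would go here
        store.update(job_id, epoch=epoch, loss=round(1.0 / epoch, 4))
    store.update(job_id, status="completed")


if __name__ == "__main__":
    store = JobStore()
    store.create("job-1")
    worker = threading.Thread(target=train, args=(store, "job-1", 5))
    worker.start()
    worker.join()
    print(store.snapshot("job-1"))
```

Wired into the endpoints, train_model would call store.update(job_id, epoch=..., loss=..., accuracy=...) once per epoch and get_status would return store.snapshot(job_id). The lock only covers threads inside one process; the multi-worker pitfall under Common Pitfalls still applies.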

Using asyncio for Concurrent Training

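If the training loop is itself asynchronous (each step awaits I/O or otherwise yields control), asyncio.create_task() can run it directly on the event loop; a CPU-bound step inside the loop would still block it. Keep a reference to each task, because the event loop holds only weak references and an unreferenced task can be garbage-collected before it finishes: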

```python
import asyncio
from uuid import uuid4

from fastapi import FastAPI, HTTPException

app = FastAPI()
training_jobs = {}
running_tasks = set()  # Strong references so tasks are not garbage-collected mid-run

async def async_train(job_id: str, epochs: int):
    training_jobs[job_id]["status"] = "running"

    try:
        for epoch in range(1, epochs + 1):
            await asyncio.sleep(1)  # Replace with actual async training step
            training_jobs[job_id].update({
                "epoch": epoch,
                "loss": round(1.0 / epoch, 4),
                "progress": round(epoch / epochs * 100, 1)
            })
        training_jobs[job_id]["status"] = "completed"
    except Exception as exc:
        training_jobs[job_id].update({"status": "failed", "error": str(exc)})

@app.post("/training/start")
async def start_training(epochs: int = 10):
    job_id = str(uuid4())
    training_jobs[job_id] = {"status": "queued", "epoch": 0, "progress": 0}
    task = asyncio.create_task(async_train(job_id, epochs))
    running_tasks.add(task)
    task.add_done_callback(running_tasks.discard)  # Drop the reference when done
    return {"job_id": job_id}

@app.get("/training/{job_id}")
async def get_status(job_id: str):
    if job_id not in training_jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    return training_jobs[job_id]
```
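Real training steps are usually blocking and CPU-bound, so awaiting them inside the loop above would stall every other request. A minimal sketch, assuming a hypothetical synchronous train_one_epoch() standing in for framework code, offloads each epoch with asyncio.to_thread() (Python 3.9+) while the coroutine keeps publishing progress between epochs:

```python
import asyncio
import time

training_jobs: dict = {}


def train_one_epoch(epoch: int) -> float:
    """Hypothetical blocking training step; returns the epoch loss."""
    time.sleep(0.01)  # Stands in for CPU/GPU work that would block the event loop
    return 1.0 / epoch


async def async_train(job_id: str, epochs: int) -> None:
    job = training_jobs[job_id]
    job["status"] = "running"
    try:
        for epoch in range(1, epochs + 1):
            # Runs in the default thread pool; the event loop stays free to
            # answer status requests while the epoch is computed
            loss = await asyncio.to_thread(train_one_epoch, epoch)
            job.update({
                "epoch": epoch,
                "loss": round(loss, 4),
                "progress": round(epoch / epochs * 100, 1),
            })
        job["status"] = "completed"
    except Exception as exc:
        job.update({"status": "failed", "error": str(exc)})


async def main() -> None:
    training_jobs["demo"] = {"status": "queued", "epoch": 0, "progress": 0}
    task = asyncio.create_task(async_train("demo", 3))
    # Another coroutine (here, a simple poller) keeps running during training
    while not task.done():
        print(training_jobs["demo"])
        await asyncio.sleep(0.005)
    await task
    print(training_jobs["demo"])


if __name__ == "__main__":
    asyncio.run(main())
```

The design choice is to keep orchestration (status updates, cancellation checks) on the event loop and push only the heavy computation into a thread; for work that holds the GIL for long stretches, a separate process such as a Celery worker is the safer option.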

Server-Sent Events for Real-Time Updates

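Server-Sent Events (SSE) keep a single HTTP response open and push a data: message whenever the server has an update, so the client does not have to poll. The stream below reads from the dict-based job store used in the asyncio example: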

```python
import asyncio
import json

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse

app = FastAPI()
training_jobs = {}  # Dict-based job store, as in the asyncio example

async def event_stream(job_id: str):
    while True:
        job = training_jobs.get(job_id)
        if job is None:
            break  # Job was removed (e.g. by cleanup); end the stream
        # Each SSE message is "data: <payload>" followed by a blank line
        yield f"data: {json.dumps(job)}\n\n"

        if job.get("status") in ("completed", "failed", "cancelled"):
            break
        await asyncio.sleep(1)

@app.get("/training/{job_id}/stream")
async def stream_status(job_id: str):
    if job_id not in training_jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    return StreamingResponse(
        event_stream(job_id),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Ask Nginx not to buffer the stream
        },
    )
```
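Each SSE message is a group of "field: value" lines ended by a blank line. A payload that contains newlines must be split across several data: lines, and a line starting with ":" is a comment that EventSource ignores, which makes it usable as a heartbeat. A small sketch of two hypothetical helpers, format_sse() and sse_heartbeat(), using the event, id, and data field names from the HTML Server-Sent Events specification:

```python
import json
from typing import Any, Optional


def format_sse(data: Any, event: Optional[str] = None,
               event_id: Optional[str] = None) -> str:
    """Serialize one Server-Sent Event (hypothetical helper)."""
    payload = data if isinstance(data, str) else json.dumps(data)
    lines = []
    if event is not None:
        # Named events are delivered to addEventListener(event, ...), not onmessage
        lines.append(f"event: {event}")
    if event_id is not None:
        # The browser sends this back as the Last-Event-ID header on reconnect
        lines.append(f"id: {event_id}")
    # Multi-line payloads become several data: lines; the client rejoins them with "\n"
    lines.extend(f"data: {line}" for line in payload.split("\n"))
    return "\n".join(lines) + "\n\n"  # The blank line terminates the event


def sse_heartbeat() -> str:
    # Comment line: ignored by EventSource, but it keeps idle proxies from closing
    return ": heartbeat\n\n"


if __name__ == "__main__":
    print(repr(format_sse({"epoch": 3, "loss": 0.3333}, event_id="3")))
    print(repr(sse_heartbeat()))
```

Inside event_stream(), yield format_sse(job) would replace the hand-built string, and yielding sse_heartbeat() on iterations with no new data addresses the reverse-proxy timeout pitfall listed under Common Pitfalls.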

Client-side consumption:

```javascript
const eventSource = new EventSource('/training/abc123/stream');
eventSource.onmessage = (event) => {
    const status = JSON.parse(event.data);
    console.log(`Epoch ${status.epoch}: loss=${status.loss}`);
    // Close on every terminal state: if the server simply ends the stream,
    // EventSource reconnects automatically
    if (['completed', 'failed', 'cancelled'].includes(status.status)) {
        eventSource.close();
    }
};
```
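Outside the browser there is no EventSource, so a script or notebook has to parse the stream itself. A minimal sketch of a hypothetical parse_sse() generator, assuming the response body is already available as an iterable of text lines (for example a streaming HTTP response's line iterator), handles only the subset of the format this endpoint produces: data: lines, comment lines, and blank-line event boundaries.

```python
import json
from typing import Iterable, Iterator


def parse_sse(lines: Iterable[str]) -> Iterator[dict]:
    """Yield one decoded JSON payload per SSE event (hypothetical, minimal parser)."""
    data_lines = []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if line == "":
            # Blank line: dispatch the accumulated event, if there is one
            if data_lines:
                yield json.loads("\n".join(data_lines))
                data_lines = []
        elif line.startswith(":"):
            continue  # Comment or heartbeat; ignored, as EventSource does
        elif line.startswith("data:"):
            value = line[len("data:"):]
            # A single leading space after the colon is not part of the value
            data_lines.append(value[1:] if value.startswith(" ") else value)
        # event:, id: and retry: fields are ignored in this sketch


def follow(lines: Iterable[str]) -> dict:
    """Print progress until a terminal state arrives; return the last status."""
    last: dict = {}
    for status in parse_sse(lines):
        last = status
        print(f"Epoch {status.get('epoch')}: loss={status.get('loss')}")
        if status.get("status") in ("completed", "failed", "cancelled"):
            break
    return last


if __name__ == "__main__":
    sample = [
        'data: {"status": "running", "epoch": 1, "loss": 1.0}\n', "\n",
        ": heartbeat\n", "\n",
        'data: {"status": "completed", "epoch": 2, "loss": 0.5}\n', "\n",
    ]
    print(follow(sample))
```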

Production Setup with Celery

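For production workloads, offload training to a Celery worker. Training then runs in a separate process (or on another machine), progress is reported with self.update_state(), and the API reads it back from the Redis result backend through AsyncResult: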

```python
# tasks.py
from celery import Celery

celery_app = Celery('tasks', broker='redis://localhost:6379/0',
                    backend='redis://localhost:6379/0')

@celery_app.task(bind=True)
def train_model_task(self, config):
    total_epochs = config['epochs']

    for epoch in range(1, total_epochs + 1):
        # Actual training code here; an uncaught exception marks the task FAILURE
        loss = 1.0 / epoch
        # Custom state; the meta dict is readable via AsyncResult.info
        self.update_state(state='PROGRESS', meta={
            'epoch': epoch,
            'total_epochs': total_epochs,
            'loss': round(loss, 4),
            'progress': round(epoch / total_epochs * 100, 1)
        })

    return {'status': 'completed', 'final_loss': round(1.0 / total_epochs, 4)}
```

```python
# main.py
from fastapi import FastAPI
from tasks import train_model_task, celery_app

app = FastAPI()

# Plain `def` endpoints: .delay() and AsyncResult make blocking calls to Redis,
# so FastAPI runs these handlers in its thread pool instead of on the event loop
@app.post("/training/start")
def start_training(epochs: int = 10):
    task = train_model_task.delay({"epochs": epochs})
    return {"task_id": task.id}

@app.get("/training/{task_id}/status")
def get_status(task_id: str):
    result = celery_app.AsyncResult(task_id)
    if result.state == 'PROGRESS':
        return {"status": "running", **result.info}
    elif result.state == 'SUCCESS':
        return {"status": "completed", **result.result}
    elif result.state == 'FAILURE':
        return {"status": "failed", "error": str(result.result)}
    # PENDING is also what Celery reports for task IDs it has never seen
    return {"status": result.state.lower()}
```
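Clients that cannot hold a stream open can poll either status endpoint instead. A minimal sketch, assuming a hypothetical fetch_status callable that performs GET /training/{task_id}/status and returns the decoded JSON, polls at a fixed interval and gives up after a deadline. The deadline matters with Celery in particular, because a mistyped or expired task ID stays "pending" forever.

```python
import time
from typing import Callable

TERMINAL_STATES = {"completed", "failed", "cancelled"}


def wait_for_training(fetch_status: Callable[[], dict],
                      interval: float = 2.0, timeout: float = 3600.0) -> dict:
    """Poll until the job reaches a terminal state or the timeout expires."""
    deadline = time.monotonic() + timeout
    while True:
        status = fetch_status()
        if status.get("status") in TERMINAL_STATES:
            return status
        if time.monotonic() >= deadline:
            # Also covers task IDs that remain "pending" because Celery never saw them
            raise TimeoutError(f"training not finished after {timeout}s: {status}")
        time.sleep(interval)


if __name__ == "__main__":
    # Canned responses standing in for real HTTP calls
    responses = iter([
        {"status": "pending"},
        {"status": "running", "epoch": 1, "progress": 50.0},
        {"status": "completed", "final_loss": 0.5},
    ])
    print(wait_for_training(lambda: next(responses), interval=0.01))
```

In a real client, fetch_status would wrap an HTTP GET with whatever library the client already uses; keeping it as a parameter lets the polling logic be tested without a server.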

WebSocket for Bidirectional Communication

```python
import asyncio

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()
training_jobs = {}  # Dict-based job store, as in the asyncio example

@app.websocket("/training/{job_id}/ws")
async def training_websocket(websocket: WebSocket, job_id: str):
    await websocket.accept()

    try:
        while True:
            job = training_jobs.get(job_id)
            if job is None:
                await websocket.send_json({"error": "Job not found"})
                break
            await websocket.send_json(job)
            if job.get("status") in ("completed", "failed", "cancelled"):
                break

            # Wait up to 2 s for a client command such as "stop", then send the next update
            try:
                data = await asyncio.wait_for(websocket.receive_text(), timeout=2.0)
            except asyncio.TimeoutError:
                continue
            if data == "stop":
                # Only flags the job; the training loop must check this and stop itself
                job["status"] = "cancelled"
                await websocket.send_json(job)
                break
    except WebSocketDisconnect:
        return  # Client disconnected; the socket is already closed

    await websocket.close()
```
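Setting the status to "cancelled" only records the request; the training code keeps running unless it checks for it. A minimal sketch of cooperative cancellation, assuming a hypothetical per-job threading.Event registry and a request_cancel() helper that the WebSocket handler would call on "stop", checks the flag between epochs and exits cleanly:

```python
import threading
from typing import Callable, Optional

training_jobs: dict = {}
cancel_events: dict = {}  # Hypothetical registry: job_id -> threading.Event


def request_cancel(job_id: str) -> None:
    """Called from the WebSocket (or an HTTP endpoint) when the client sends "stop"."""
    event = cancel_events.get(job_id)
    if event is not None:
        event.set()  # Safe to call from the event loop while training runs in a thread


def train(job_id: str, epochs: int,
          step: Optional[Callable[[int], float]] = None) -> None:
    job = training_jobs[job_id]
    cancel = cancel_events.setdefault(job_id, threading.Event())
    job["status"] = "running"
    for epoch in range(1, epochs + 1):
        if cancel.is_set():
            # Checked between epochs, so a running epoch is allowed to finish
            job["status"] = "cancelled"
            return
        loss = step(epoch) if step else 1.0 / epoch  # Real training step goes here
        job.update({"epoch": epoch, "loss": loss})
    job["status"] = "completed"


if __name__ == "__main__":
    # Register the event when the job is created so an early "stop" is not lost
    training_jobs["job-1"] = {"status": "queued", "epoch": 0}
    cancel_events["job-1"] = threading.Event()

    def step(epoch: int) -> float:
        if epoch == 2:
            request_cancel("job-1")  # Simulate a "stop" arriving during epoch 2
        return 1.0 / epoch

    worker = threading.Thread(target=train, args=("job-1", 10, step))
    worker.start()
    worker.join()
    print(training_jobs["job-1"])  # status "cancelled", epoch 2
```

Checking between epochs keeps the model in a consistent state at the cost of some latency; a finer-grained check (for example every N batches) shortens that delay.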

Common Pitfalls

  • Blocking the event loop with synchronous training: CPU-bound training code (PyTorch, scikit-learn) blocks FastAPI's async event loop when it runs inside an async def endpoint or an async def background task. BackgroundTasks runs plain def functions in a thread pool, so define the training function with def, or move it to a separate process (Celery). Use asyncio.to_thread() if you must call sync code from an async context.
  • Using in-memory dict for job state in multi-worker deployments: An in-memory dictionary is only visible to the worker process that created it. With multiple Uvicorn workers or Gunicorn processes, requests may hit a different worker that has no knowledge of the job. Use Redis or a database for shared state.
  • Not handling training failures: If the training function raises an exception, the job status stays as "running" forever. Wrap the training function in try/except and update the status to "failed" with the error message in the except block.
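  • SSE connections timing out behind reverse proxies: Nginx and other reverse proxies close idle connections after a read timeout (60s by default for Nginx's proxy_read_timeout), and Nginx also buffers proxied responses, which delays events. Configure proxy_read_timeout 3600s in Nginx or send periodic heartbeat comments (": heartbeat\n\n") to keep the connection alive, and disable buffering with proxy_buffering off or an X-Accel-Buffering: no response header.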
  • Memory leaks from accumulating job records: Without cleanup, the training_jobs dictionary grows indefinitely. Implement a TTL-based cleanup (delete completed jobs after N hours) or use Redis with EXPIRE to automatically remove stale entries.
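The last two pitfalls share a fix: keep job state outside the worker process and expire old records. As a stand-in for Redis, a minimal sketch using the standard-library sqlite3 module, assuming all Uvicorn/Gunicorn workers run on one host, open the same database file, and write at a modest rate (the SqliteJobStore class and its schema are hypothetical), stores each job as JSON with an updated_at timestamp and purges finished jobs older than a TTL:

```python
import json
import sqlite3
import threading
import time
from typing import Optional


class SqliteJobStore:
    """Hypothetical job store shared by every process that opens the same file."""

    def __init__(self, path: str = "jobs.db"):
        self._lock = threading.Lock()  # One connection per process, guarded by a lock
        self._conn = sqlite3.connect(path, check_same_thread=False)
        with self._lock, self._conn:
            self._conn.execute(
                "CREATE TABLE IF NOT EXISTS jobs ("
                "job_id TEXT PRIMARY KEY, status TEXT NOT NULL, "
                "data TEXT NOT NULL, updated_at REAL NOT NULL)"
            )

    def save(self, job_id: str, job: dict) -> None:
        # The connection's context manager commits on success, rolls back on error
        with self._lock, self._conn:
            self._conn.execute(
                "INSERT OR REPLACE INTO jobs VALUES (?, ?, ?, ?)",
                (job_id, job["status"], json.dumps(job), time.time()),
            )

    def load(self, job_id: str) -> Optional[dict]:
        with self._lock:
            row = self._conn.execute(
                "SELECT data FROM jobs WHERE job_id = ?", (job_id,)
            ).fetchone()
        return json.loads(row[0]) if row else None

    def purge_finished(self, ttl_seconds: float, now: Optional[float] = None) -> int:
        """Delete finished jobs not updated within ttl_seconds; return the count."""
        cutoff = (time.time() if now is None else now) - ttl_seconds
        with self._lock, self._conn:
            cur = self._conn.execute(
                "DELETE FROM jobs WHERE status IN ('completed', 'failed', 'cancelled') "
                "AND updated_at < ?",
                (cutoff,),
            )
        return cur.rowcount


if __name__ == "__main__":
    store = SqliteJobStore(":memory:")  # Use a file path to share across workers
    store.save("a", {"status": "completed", "epoch": 10})
    store.save("b", {"status": "running", "epoch": 3})
    # Simulate a sweep two hours later with a one-hour TTL: only "a" is removed
    print(store.purge_finished(3600, now=time.time() + 7200))
    print(store.load("a"), store.load("b"))
```

Calling purge_finished() periodically (or on each POST /training/start) keeps the table bounded; running jobs are never purged. With Redis, setting EXPIRE on each finished job's key achieves the same without a sweeper.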

Summary

  • Run training in a background task and expose a GET /status endpoint for polling
  • Use Server-Sent Events (text/event-stream) for real-time push updates without polling
  • For production, use Celery with Redis backend and self.update_state() for progress tracking
  • Store job state in Redis or a database — not in-memory — when running multiple workers
  • Wrap training code in try/except to update status on failure and implement job cleanup for completed tasks
