Introduction
To monitor model training via FastAPI, run the training in a background task (using BackgroundTasks, asyncio, or a task queue like Celery) and expose status endpoints that clients can poll or stream via Server-Sent Events (SSE). The training function updates a shared state object (dictionary, database row, or Redis key) with metrics like current epoch, loss, accuracy, and progress percentage. The client queries GET /training/{job_id}/status to retrieve the latest state. This pattern decouples the long-running training from the HTTP request-response cycle.
Basic Setup with Background Tasks
from datetime import datetime
from uuid import uuid4

from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel
5
app = FastAPI()


class TrainingStatus(BaseModel):
    """Progress snapshot for one training job, serialized by the status endpoint.

    The background training function mutates these fields in place as it runs.
    """

    job_id: str
    status: str  # "queued", "running", "completed", "failed"

    # Progress counters, refreshed after each finished epoch.
    epoch: int = 0
    total_epochs: int = 0
    loss: float = 0.0
    accuracy: float = 0.0

    # Lifecycle timestamps and failure message; None until set.
    started_at: datetime | None = None
    completed_at: datetime | None = None
    error: str | None = None


# In-memory job store keyed by job_id (use Redis or a database in production).
training_jobs: dict[str, TrainingStatus] = {}
21
def train_model(job_id: str, config: dict):
    """Simulate model training, recording progress on the job's TrainingStatus.

    Mutates ``training_jobs[job_id]`` in place after every epoch. If anything
    raises, the job is marked "failed" with the exception message in ``error``
    instead of being left in "running" forever. ``completed_at`` records when
    the job stopped, whether it succeeded or failed.

    Args:
        job_id: key of an existing entry in ``training_jobs``.
        config: training options; ``epochs`` (default 10) sets the epoch count.
    """
    import logging
    import time

    job = training_jobs[job_id]
    job.status = "running"
    job.started_at = datetime.now()
    try:
        total_epochs = config.get("epochs", 10)
        job.total_epochs = total_epochs

        for epoch in range(1, total_epochs + 1):
            time.sleep(2)  # Simulate epoch training time
            job.epoch = epoch
            job.loss = 1.0 / epoch
            job.accuracy = 1 - (1.0 / (epoch + 1))
    except Exception as exc:
        # Surface the failure to pollers and keep the traceback for operators.
        logging.getLogger(__name__).exception("Training job %s failed", job_id)
        job.status = "failed"
        job.error = str(exc)
    else:
        job.status = "completed"
    finally:
        job.completed_at = datetime.now()
39
@app.post("/training/start")
async def start_training(background_tasks: BackgroundTasks, epochs: int = 10):
    """Register a queued job and schedule ``train_model`` for it.

    ``epochs`` is an optional query parameter (default 10). It is stored on the
    job right away, so clients see the planned total before training begins,
    and it is passed to the training function.

    Raises:
        HTTPException: 422 if ``epochs`` is less than 1.
    """
    if epochs < 1:
        raise HTTPException(status_code=422, detail="epochs must be at least 1")

    job_id = str(uuid4())
    training_jobs[job_id] = TrainingStatus(
        job_id=job_id, status="queued", total_epochs=epochs
    )
    background_tasks.add_task(train_model, job_id, {"epochs": epochs})
    return {"job_id": job_id, "message": "Training started"}
48
@app.get("/training/{job_id}/status")
async def get_status(job_id: str):
    """Return the current TrainingStatus for ``job_id``.

    Raises:
        HTTPException: 404 if no job with that id exists.
    """
    job = training_jobs.get(job_id)
    if job is None:
        # Returning a (body, status) tuple is a Flask idiom; FastAPI would
        # serialize the tuple as a JSON array with status 200, so raise instead.
        raise HTTPException(status_code=404, detail="Job not found")
    return job
Using asyncio for Concurrent Training
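For better concurrency with async-compatible training loops, schedule the training coroutine on the running event loop with asyncio.create_task(). This only helps when each training step actually awaits (for example, async I/O); a CPU-bound step still blocks the event loop (see Common Pitfalls below):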
1import asyncio
2from fastapi import FastAPI
3
app = FastAPI()
# job_id -> dict of status/progress fields; lives only in this process's memory.
training_jobs: dict[str, dict] = {}
6
async def async_train(job_id: str, epochs: int):
    """Simulate an async training loop that publishes progress to training_jobs.

    After each epoch, the job's dict is updated with ``epoch``, ``loss`` and
    ``progress`` (percent complete). Another handler may set the job's status
    to "cancelled" (for example, a WebSocket "stop" command). In that case the
    loop exits at the next epoch boundary and does not mark the job completed.
    An exception marks the job "failed" and stores its message under
    ``error``, instead of leaving it reported as "running" forever.
    """
    import logging

    job = training_jobs[job_id]
    job["status"] = "running"
    try:
        for epoch in range(1, epochs + 1):
            await asyncio.sleep(1)  # Replace with actual async training step
            # Other handlers only run while we await, so checking right after
            # the step catches any cancellation requested during it.
            if job["status"] == "cancelled":
                return
            job.update({
                "epoch": epoch,
                "loss": round(1.0 / epoch, 4),
                "progress": round(epoch / epochs * 100, 1),
            })
    except Exception as exc:
        logging.getLogger(__name__).exception("Training job %s failed", job_id)
        job["status"] = "failed"
        job["error"] = str(exc)
    else:
        job["status"] = "completed"
19
# Strong references to in-flight training tasks. The event loop keeps only weak
# references, so an otherwise unreferenced task can be garbage-collected mid-run.
_training_tasks: set[asyncio.Task] = set()


@app.post("/training/start")
async def start_training(epochs: int = 10):
    """Create a queued job and start ``async_train`` for it on the event loop.

    Returns the new job id immediately; progress is read via the status route.
    """
    from uuid import uuid4  # this example's top-level imports do not provide it

    job_id = str(uuid4())
    training_jobs[job_id] = {"status": "queued", "epoch": 0, "progress": 0}
    task = asyncio.create_task(async_train(job_id, epochs))
    _training_tasks.add(task)
    task.add_done_callback(_training_tasks.discard)
    return {"job_id": job_id}
26
@app.get("/training/{job_id}")
async def get_status(job_id: str):
    """Return the job's status dict.

    Raises:
        HTTPException: 404 for an unknown job id (previously a 200 response
            with an error body, which clients could mistake for success).
    """
    from fastapi import HTTPException  # not among this example's imports

    job = training_jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Not found")
    return job
Server-Sent Events for Real-Time Updates
SSE pushes updates to the client without polling:
1from fastapi import FastAPI
2from fastapi.responses import StreamingResponse
3import asyncio
4import json
5
app = FastAPI()
# Shared job-state dicts, keyed by job_id; each is serialized to JSON per SSE frame.
training_jobs: dict[str, dict] = {}
8
async def event_stream(job_id: str, interval: float = 1.0):
    """Yield SSE ``data:`` frames carrying the job's state until it finishes.

    One frame is emitted every ``interval`` seconds (default 1). The stream ends
    after the job reaches "completed", "failed" or "cancelled". It also ends as
    soon as the job is absent from ``training_jobs`` (unknown id or cleaned up),
    so the generator never idles forever on a job that cannot progress.
    """
    while True:
        job = training_jobs.get(job_id)
        if job is None:
            break
        yield f"data: {json.dumps(job)}\n\n"

        if job.get("status") in ("completed", "failed", "cancelled"):
            break
        await asyncio.sleep(interval)
18
@app.get("/training/{job_id}/stream")
async def stream_status(job_id: str):
    """Stream the job's status as Server-Sent Events.

    Raises:
        HTTPException: 404 for an unknown job id. A non-200 response makes the
            browser's EventSource give up instead of reconnecting in a loop.
    """
    from fastapi import HTTPException  # not among this example's imports

    if job_id not in training_jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    return StreamingResponse(
        event_stream(job_id),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask Nginx not to buffer the response so each event is sent immediately.
            "X-Accel-Buffering": "no",
        },
    )
Client-side consumption:
// Subscribe to status updates for one job (replace abc123 with a real job_id).
const eventSource = new EventSource('/training/abc123/stream');

// Statuses after which the server ends the stream.
const TERMINAL_STATES = new Set(['completed', 'failed', 'cancelled']);

eventSource.onmessage = (event) => {
  const status = JSON.parse(event.data);
  console.log(`Epoch ${status.epoch}: loss=${status.loss}`);
  // Close on every terminal state. Once the server ends the stream, EventSource
  // reconnects automatically, and it would keep replaying the final status.
  if (TERMINAL_STATES.has(status.status)) {
    eventSource.close();
  }
};
Production Setup with Celery
For production workloads, offload training to a Celery worker:
1# tasks.py
2from celery import Celery
3
# A single local Redis database serves as both the broker (task queue) and the
# result backend (where task state and progress are stored).
_REDIS_URL = 'redis://localhost:6379/0'
celery_app = Celery('tasks', broker=_REDIS_URL, backend=_REDIS_URL)
6
@celery_app.task(bind=True)
def train_model_task(self, config):
    """Run training in a Celery worker, publishing per-epoch progress.

    Progress is reported with ``self.update_state(state='PROGRESS', meta=...)``.
    The returned dict becomes the task result once it succeeds.

    Args:
        config: mapping with an ``epochs`` key (number of epochs to run).
    """
    total_epochs = config['epochs']
    loss = None  # stays None when no epoch runs (total_epochs < 1)

    for epoch in range(1, total_epochs + 1):
        # Actual training code here
        loss = 1.0 / epoch
        self.update_state(state='PROGRESS', meta={
            'epoch': epoch,
            'total_epochs': total_epochs,
            'loss': round(loss, 4),
            'progress': round(epoch / total_epochs * 100, 1),
        })

    # Report the last loss actually observed instead of recomputing it. This
    # also avoids a ZeroDivisionError when total_epochs is 0.
    final_loss = round(loss, 4) if loss is not None else None
    return {'status': 'completed', 'final_loss': final_loss}
1# main.py
2from fastapi import FastAPI
3from tasks import train_model_task, celery_app
4
app = FastAPI()


@app.post("/training/start")
def start_training(epochs: int = 10):
    """Enqueue a Celery training task and return its id for status polling.

    Declared with plain ``def`` on purpose: ``delay()`` publishes to the broker
    with blocking I/O, so FastAPI runs this handler in its threadpool instead of
    stalling the event loop.
    """
    task = train_model_task.delay({"epochs": epochs})
    return {"task_id": task.id}
11
@app.get("/training/{task_id}/status")
def get_status(task_id: str):
    """Translate Celery's task state into this API's status vocabulary.

    Declared with plain ``def`` so FastAPI runs it in its threadpool, because
    reading task state queries the result backend synchronously.

    Note: Celery reports any unknown task id as PENDING, so a mistyped id
    yields {"status": "pending"} rather than a 404.
    """
    result = celery_app.AsyncResult(task_id)
    state = result.state  # read once so every branch sees the same state
    if state == 'PROGRESS':
        info = result.info
        # state and info may be fetched separately; if the task finished in
        # between, info might not be the progress dict anymore.
        return {"status": "running", **(info if isinstance(info, dict) else {})}
    if state == 'SUCCESS':
        return {"status": "completed", **result.result}
    if state == 'FAILURE':
        return {"status": "failed", "error": str(result.result)}
    return {"status": state.lower()}
WebSocket for Bidirectional Communication
1from fastapi import FastAPI, WebSocket
2
app = FastAPI()


@app.websocket("/training/{job_id}/ws")
async def training_websocket(websocket: WebSocket, job_id: str):
    """Push the job's state to the client and accept control commands.

    Sends the job's state as JSON, then waits up to 2 s for a client message,
    and repeats. The socket closes once the job is "completed", "failed" or
    "cancelled". Sending the text "stop" marks the job "cancelled"; the training
    loop must check that status for the cancellation to take effect.

    NOTE(review): ``training_jobs`` is not defined in this snippet; it is assumed
    to be the shared dict-based job store from the asyncio example.
    """
    import asyncio  # missing from this example's imports

    from fastapi import WebSocketDisconnect

    await websocket.accept()
    try:
        while True:
            job = training_jobs.get(job_id)
            if job is None:
                # Unknown or cleaned-up job: close (1008, policy violation)
                # instead of idling forever.
                await websocket.close(code=1008)
                return

            await websocket.send_json(job)
            if job.get("status") in ("completed", "failed", "cancelled"):
                break

            # Client can send commands like "stop" or "pause"
            try:
                data = await asyncio.wait_for(websocket.receive_text(), timeout=2.0)
            except asyncio.TimeoutError:
                continue
            if data == "stop":
                job["status"] = "cancelled"
                break
    except WebSocketDisconnect:
        # The client went away; there is nothing left to send or close.
        return

    await websocket.close()
Common Pitfalls
Blocking the event loop with synchronous training: CPU-bound training code (PyTorch, scikit-learn) blocks FastAPI's async event loop. Run synchronous training in BackgroundTasks or a separate process (Celery), not directly in an async def endpoint. BackgroundTasks runs a plain def task in a threadpool, but an async def task runs on the event loop itself, so keep CPU-bound training in plain functions. Use asyncio.to_thread() if you must call sync code from async context.
Using in-memory dict for job state in multi-worker deployments: An in-memory dictionary is only visible to the worker process that created it. With multiple Uvicorn workers or Gunicorn processes, requests may hit a different worker that has no knowledge of the job. Use Redis or a database for shared state.
Not handling training failures: If the training function raises an exception, the job status stays as "running" forever. Wrap the training function in try/except and update the status to "failed" with the error message in the except block.
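SSE connections timing out behind reverse proxies: Nginx and other reverse proxies have default timeouts (60s) that close idle SSE connections. Configure proxy_read_timeout 3600s in Nginx or send periodic heartbeat comments (": heartbeat\n\n") to keep the connection alive. Also disable response buffering for the stream (proxy_buffering off, or an X-Accel-Buffering: no response header); otherwise Nginx may hold events back and deliver them in bursts.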
Memory leaks from accumulating job records: Without cleanup, the training_jobs dictionary grows indefinitely. Implement a TTL-based cleanup (delete completed jobs after N hours) or use Redis with EXPIRE to automatically remove stale entries.
Summary
Run training in a background task and expose a GET /status endpoint for polling
Use Server-Sent Events (text/event-stream) for real-time push updates without polling
For production, use Celery with Redis backend and self.update_state() for progress tracking
Store job state in Redis or a database — not in-memory — when running multiple workers
Wrap training code in try/except to update status on failure and implement job cleanup for completed tasks