DRF at Scale: Authentication, Throttling, Versioning, and Caching
Production patterns for Django REST Framework — custom JWT auth, per-endpoint throttling, URL and header versioning, response caching, and query optimization.
DRF works out of the box for simple APIs. As traffic grows and the API surface expands, you need deliberate choices around authentication, rate limiting, versioning, and caching. This article documents the patterns that hold up at scale.
Custom JWT Authentication
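rest_framework.authentication.TokenAuthentication is a database-backed scheme — every request looks the token up in the authtoken table. For high-traffic APIs, a signed JWT is verified in-process, which removes that token lookup. Note that the example below still loads the user row on each request; cache the user or build it from the token claims if that query matters too.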
# authentication.py
import jwt
from django.conf import settings
from django.contrib.auth import get_user_model
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed
User = get_user_model()
class JWTAuthentication(BaseAuthentication):
    """Authenticate requests that carry ``Authorization: Bearer <JWT>``.

    Tokens are HS256-signed with ``settings.JWT_SECRET``. Only tokens whose
    ``type`` claim is ``"access"`` are accepted, so a long-lived refresh
    token cannot be replayed as an API credential.
    """

    def authenticate_header(self, request):
        """Return the ``WWW-Authenticate`` value so DRF answers 401, not 403."""
        return "Bearer"

    def authenticate(self, request):
        """Return ``(user, token)`` for a valid access token.

        Returns ``None`` when the request has no ``Bearer`` header, which
        lets the next configured authenticator run.

        Raises:
            AuthenticationFailed: the token is expired, malformed, not an
                access token, or refers to a missing or inactive user.
        """
        header = request.headers.get("Authorization", "")
        if not header.startswith("Bearer "):
            return None  # try next authenticator

        token = header.split(" ", 1)[1]
        try:
            payload = jwt.decode(
                token,
                settings.JWT_SECRET,
                algorithms=["HS256"],
                # A token without "exp" would otherwise never expire.
                options={"require": ["exp"]},
            )
        except jwt.ExpiredSignatureError:
            raise AuthenticationFailed("Token expired.")
        except jwt.InvalidTokenError:
            raise AuthenticationFailed("Invalid token.")

        # Refresh tokens are signed with the same secret; reject them here.
        if payload.get("type") != "access":
            raise AuthenticationFailed("Invalid token.")

        user_id = payload.get("user_id")
        if user_id is None:
            raise AuthenticationFailed("Invalid token.")

        try:
            user = User.objects.select_related("profile").get(pk=user_id)
        except (User.DoesNotExist, ValueError, TypeError):
            # ValueError/TypeError: the claim is not a usable primary key.
            raise AuthenticationFailed("User not found.")

        if not user.is_active:
            raise AuthenticationFailed("User inactive or deleted.")

        return (user, token)  # (user, auth) tuple
# settings.py
# Authenticators are listed in the order given here: JWT first, session second.
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": [
        "myapp.authentication.JWTAuthentication",
        # fallback for browser
        "rest_framework.authentication.SessionAuthentication",
    ],
}
Token Issuance
import jwt
import datetime
from django.conf import settings
def issue_token(user) -> dict:
    """Issue an access/refresh JWT pair for *user*.

    Both tokens are HS256-signed with ``settings.JWT_SECRET`` and carry
    ``user_id``, ``iat``, ``exp`` and ``type``. The access token lives for
    15 minutes and also carries ``email``; the refresh token lives for
    30 days.

    Returns:
        A dict with the encoded tokens under ``"access"`` and ``"refresh"``.
    """
    # Timezone-aware "now": datetime.utcnow() returns a naive value and is
    # deprecated since Python 3.12.
    now = datetime.datetime.now(datetime.timezone.utc)

    def _sign(token_type, lifetime, **extra_claims):
        # Build and sign one token; claim order matches the original payloads.
        claims = {
            "user_id": user.pk,
            **extra_claims,
            "iat": now,
            "exp": now + lifetime,
            "type": token_type,
        }
        return jwt.encode(claims, settings.JWT_SECRET, algorithm="HS256")

    return {
        "access": _sign(
            "access", datetime.timedelta(minutes=15), email=user.email
        ),
        "refresh": _sign("refresh", datetime.timedelta(days=30)),
    }
Throttling
DRF’s built-in throttling tracks request counts with the cache backend.
# settings.py
# Project-wide throttles; each rate key below names the scope it applies to.
REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES": [
        "rest_framework.throttling.AnonRateThrottle",
        "rest_framework.throttling.UserRateThrottle",
    ],
    "DEFAULT_THROTTLE_RATES": {"anon": "60/hour", "user": "1000/hour"},
}
Per-View Throttling
from rest_framework.throttling import UserRateThrottle, AnonRateThrottle
class LoginRateThrottle(AnonRateThrottle):
    """Anonymous-request throttle with its own ``login`` scope."""

    scope = "login"
    rate = "5/minute"  # strict — prevent brute force
class SearchRateThrottle(UserRateThrottle):
    """Per-user throttle with its own ``search`` scope."""

    scope = "search"
    rate = "30/minute"
class LoginView(APIView):
    """Login endpoint: open to everyone, guarded by ``LoginRateThrottle``."""

    permission_classes = [AllowAny]
    throttle_classes = [LoginRateThrottle]

    def post(self, request):
        # Handler body is elided in this example.
        ...
class SearchView(generics.ListAPIView):
    # Apply the search-specific throttle to this view.
    throttle_classes = [SearchRateThrottle]
    # Remaining view configuration is elided in this example.
    ...
Custom Throttle: Per-Organization
class OrganizationThrottle(BaseThrottle):
    """Throttle by org plan tier instead of individual user.

    Each organization gets one shared counter per fixed ``WINDOW``.
    Anonymous users and users without an organization are not limited by
    this class.
    """

    RATES = {"free": 100, "pro": 5000, "enterprise": 50000}
    WINDOW = 3600  # 1 hour

    def allow_request(self, request, view):
        """Count this request against the org's quota; return False once over it."""
        if not request.user.is_authenticated:
            return True
        # getattr mirrors HasOrganizationPermission: a user may have no org.
        org = getattr(request.user, "organization", None)
        if org is None:
            return True

        cache_key = f"throttle:org:{org.pk}"
        limit = self.RATES.get(org.plan, 100)

        # add() writes only when the key is absent, so the expiry is set once
        # per window instead of being pushed back by every request.
        cache.add(cache_key, 0, self.WINDOW)
        try:
            # incr() is atomic on Redis/Memcached, so concurrent requests
            # cannot overwrite each other's counts.
            # NOTE(review): backends relying on Django's generic incr()
            # re-set the key with the default timeout — confirm the backend.
            count = cache.incr(cache_key)
        except ValueError:
            # The key expired between add() and incr(); start a new window.
            cache.set(cache_key, 1, self.WINDOW)
            count = 1

        if count > limit:
            # The generic cache API exposes no remaining TTL, so report the
            # full window as an upper bound.
            self.wait_time = self.WINDOW
            return False
        return True

    def wait(self):
        """Return the suggested wait in seconds, or None if never throttled."""
        return getattr(self, "wait_time", None)
API Versioning
URL Versioning (recommended for public APIs)
# urls.py
from django.urls import path, include
# URLPathVersioning reads the version from a ``version`` URL keyword argument.
# The literal "api/v1/" prefix captures nothing, so without the extra kwargs
# every request would fall back to DEFAULT_VERSION. Passing the dict to
# include() forwards ``version`` to every view in the included URLconf while
# keeping the URLs and namespaces unchanged.
urlpatterns = [
    path(
        "api/v1/",
        include(("api.v1.urls", "v1"), namespace="v1"),
        {"version": "v1"},
    ),
    path(
        "api/v2/",
        include(("api.v2.urls", "v2"), namespace="v2"),
        {"version": "v2"},
    ),
]
# settings.py
# Versioning scheme, the versions accepted, and the one assumed when none is given.
REST_FRAMEWORK = {
    "DEFAULT_VERSIONING_CLASS": (
        "rest_framework.versioning.URLPathVersioning"
    ),
    "ALLOWED_VERSIONS": ["v1", "v2"],
    "DEFAULT_VERSION": "v1",
}
# Use version in a view
class ArticleViewSet(viewsets.ModelViewSet):
    """Article endpoints whose serializer depends on the request's API version."""

    def get_serializer_class(self):
        """Return the V2 serializer for version ``"v2"``, the V1 serializer otherwise."""
        wants_v2 = self.request.version == "v2"
        return ArticleSerializerV2 if wants_v2 else ArticleSerializerV1
Header Versioning (for internal/mobile clients)
# settings.py
# Select header-based versioning as the project default.
REST_FRAMEWORK = {
    "DEFAULT_VERSIONING_CLASS": (
        "rest_framework.versioning.AcceptHeaderVersioning"
    ),
}
# Client sends: Accept: application/json; version=2
Namespace Versioning (cleanest for large APIs)
# settings.py
# Select namespace-based versioning as the project default.
REST_FRAMEWORK = {
    "DEFAULT_VERSIONING_CLASS": (
        "rest_framework.versioning.NamespaceVersioning"
    ),
}
# URL: /api/v2/articles/ with app_name="v2" in urls.py
Custom Permissions
from rest_framework.permissions import BasePermission, SAFE_METHODS
class IsOwnerOrAdmin(BasePermission):
    """Object-level rule: staff and the owner may write; everyone else may only read."""

    def has_permission(self, request, view):
        """Require an authenticated user before any object check runs."""
        return request.user.is_authenticated

    def has_object_permission(self, request, view, obj):
        """Allow staff, any safe (read-only) method, or the object's owner."""
        user = request.user
        if user.is_staff or request.method in SAFE_METHODS:
            return True
        return obj.owner == user
class HasOrganizationPermission(BasePermission):
    """View-level: checks org membership and plan.

    Access is granted when the user's organization is on ``required_plan``
    or a higher tier, with tiers ranked by their position in ``PLAN_ORDER``.
    """

    required_plan = "free"  # override in subclass
    PLAN_ORDER = ("free", "pro", "enterprise")  # lowest tier first

    def has_permission(self, request, view):
        """Return True only for authenticated members of a sufficiently high plan."""
        if not request.user.is_authenticated:
            return False
        org = getattr(request.user, "organization", None)
        if org is None:
            return False
        try:
            org_rank = self.PLAN_ORDER.index(org.plan)
        except ValueError:
            # Unknown plan value in the data: deny rather than raise a 500.
            return False
        # A misspelled required_plan is a programming error and still raises.
        return org_rank >= self.PLAN_ORDER.index(self.required_plan)
class RequireProPlan(HasOrganizationPermission):
    """Organization permission that sets the minimum plan to ``"pro"``."""

    required_plan = "pro"
Response Caching
Cache by URL + User
from django.core.cache import cache
from rest_framework.response import Response
import hashlib
class CachedListMixin:
    """Cache ``list()`` responses per path, query string and user.

    Place it before the DRF list view class in the bases so that
    ``super().list`` resolves to the real implementation.
    """

    cache_timeout = 300  # 5 minutes

    def get_cache_key(self, request):
        """Build the cache key from the request path, query string and user id."""
        params = request.query_params.urlencode()
        user_id = request.user.pk if request.user.is_authenticated else "anon"
        raw = f"{request.path}:{params}:{user_id}"
        # md5 only shortens the key here; it is not used for security.
        return f"api_cache:{hashlib.md5(raw.encode()).hexdigest()}"

    def list(self, request, *args, **kwargs):
        """Serve the cached payload if present; otherwise compute and store it."""
        key = self.get_cache_key(request)
        cached = cache.get(key)
        # Compare against None: an empty result ([] or {}) is falsy but is
        # still a valid cached payload.
        if cached is not None:
            return Response(cached)
        response = super().list(request, *args, **kwargs)
        cache.set(key, response.data, self.cache_timeout)
        return response
class ArticleListView(CachedListMixin, generics.ListAPIView):
    """Cached list of published articles, fetched together with their authors."""

    serializer_class = ArticleSerializer
    queryset = Article.objects.select_related("author").filter(published=True)
    cache_timeout = 60  # 1 minute for frequently changing data
Cache Invalidation on Write
class ArticleViewSet(viewsets.ModelViewSet):
    """Article endpoints that drop cached list responses after every write."""

    def perform_create(self, serializer):
        """Save the new article with the requester as author, then invalidate."""
        serializer.save(author=self.request.user)
        self._invalidate_list_cache()

    def perform_update(self, serializer):
        """Save the changes, then invalidate."""
        serializer.save()
        self._invalidate_list_cache()

    def perform_destroy(self, instance):
        """Delete the article, then invalidate so it leaves cached lists."""
        instance.delete()
        self._invalidate_list_cache()

    def _invalidate_list_cache(self):
        # delete_pattern() is a django-redis extension. It removes matching
        # keys without first loading them all through keys(), which runs the
        # blocking Redis KEYS command.
        cache.delete_pattern("api_cache:*")
Serializer Optimization
SerializerMethodField vs Annotated Fields
# Slow — get_comment_count() runs once for every serialized article
class ArticleSerializer(serializers.ModelSerializer):
    # Value comes from get_comment_count() below.
    comment_count = serializers.SerializerMethodField()

    def get_comment_count(self, obj):
        """Return the number of comments on *obj*."""
        return obj.comments.count()  # N+1 if not prefetched!
# Fast — DB does the work, single query
from django.db.models import Count
class ArticleViewSet(viewsets.ModelViewSet):
    """Article endpoints whose queryset carries a per-article comment count."""

    def get_queryset(self):
        """Return all articles annotated with ``comment_count``."""
        annotated = Article.objects.annotate(comment_count=Count("comments"))
        return annotated
class ArticleSerializer(serializers.ModelSerializer):
    # Read-only integer; presumably filled from the ``comment_count``
    # annotation added in get_queryset() — confirm the view uses that queryset.
    comment_count = serializers.IntegerField(read_only=True)
Conditional Field Serialization
class ArticleSerializer(serializers.ModelSerializer):
    """Article serializer that shows ``internal_notes`` to staff only."""

    class Meta:
        model = Article
        fields = ["id", "title", "body", "author", "internal_notes"]

    def to_representation(self, instance):
        """Serialize *instance*, dropping ``internal_notes`` for non-staff requests."""
        data = super().to_representation(instance)
        request = self.context.get("request")
        if request and request.user.is_staff:
            return data
        # No request in context, or a non-staff user: hide the field.
        data.pop("internal_notes", None)
        return data
Pagination for Large Datasets
from rest_framework.pagination import CursorPagination
class ArticleCursorPagination(CursorPagination):
    """Cursor-based: O(1) regardless of page depth. Use for feeds."""

    cursor_query_param = "cursor"
    ordering = "-created_at"
    page_size = 20
# For admin/search where random access matters — use PageNumber
from rest_framework.pagination import PageNumberPagination
class ArticlePagePagination(PageNumberPagination):
    """Page-number pagination: 20 items by default, client-adjustable up to 100."""

    max_page_size = 100
    page_size_query_param = "page_size"
    page_size = 20
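Cursor pagination keeps a roughly constant per-page cost at any page depth — provided the ordering column is indexed — and prevents “page drift” (items shifting between pages while the user is browsing). Use it for any high-volume feed. Offset pagination (LIMIT x OFFSET y) slows down as the offset grows, noticeably so on PostgreSQL past ~100k rows, because the database must scan and discard every skipped row.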