from __future__ import annotations
import logging
import re
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Set
import jwt
from baseplate import RequestContext
from baseplate.lib import cached_property
from baseplate.lib.edgecontext import EdgeContextFactory as BaseEdgeContextFactory
from baseplate.lib.secrets import SecretsStore
from jwt.algorithms import get_default_algorithms
from thrift import TSerialization
from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory
from reddit_edgecontext.thrift.ttypes import Device as TDevice
from reddit_edgecontext.thrift.ttypes import Geolocation as TGeolocation
from reddit_edgecontext.thrift.ttypes import Locale as TLocale
from reddit_edgecontext.thrift.ttypes import Loid as TLoid
from reddit_edgecontext.thrift.ttypes import OriginService as TOriginService
from reddit_edgecontext.thrift.ttypes import Request as TRequest
from reddit_edgecontext.thrift.ttypes import RequestId as TRequestId
from reddit_edgecontext.thrift.ttypes import Session as TSession
logger = logging.getLogger(__name__)

# Two upper-case ASCII letters (ISO 3166-1 alpha-2 shape). ``\Z`` is used
# rather than ``$`` so that a value with a trailing newline is rejected.
COUNTRY_CODE_RE = re.compile(r"^[A-Z]{2}\Z")

# A lower-case primary subtag of two or more letters, followed by any number
# of subtags (two or more alphanumerics each) separated by "_" or "-".
# NOTE: the separator class previously read ``[_|\-]``; inside a character
# class ``|`` is a literal pipe, so values like "en|US" were accepted.
LOCALE_CODE_RE = re.compile(r"^[a-z]{2,}([_\-][\da-zA-Z]{2,})*\Z")
class NoAuthenticationError(Exception):
    """Signal that no usable authentication token is available.

    Raised when code tries to read data from an authentication token that is
    missing or invalid.
    """
class AuthenticationTokenValidator:
    """Turn raw authentication tokens into :py:class:`AuthenticationToken` objects.

    Tokens are JWTs verified with the RS256 algorithm against every version of
    the public key stored at ``secret/authentication/public-key``. Prepared
    keys are cached and rebuilt only when the secret's mtime advances.
    """

    def __init__(self, secrets: SecretsStore):
        self.secrets = secrets
        self._algorithm_name = "RS256"
        self._algorithm = get_default_algorithms()[self._algorithm_name]
        # mtime of the secrets data the cached keys were prepared from.
        self._cache_mtime = 0.0
        self._public_keys: List[Any] = []

    def _refresh_public_keys(self) -> None:
        """Re-prepare the cached public keys if the secret has changed."""
        versioned, modified_at = self.secrets.get_versioned_and_mtime(
            "secret/authentication/public-key"
        )
        if modified_at > self._cache_mtime:
            self._public_keys = [
                self._algorithm.prepare_key(raw_key) for raw_key in versioned.all_versions
            ]
            self._cache_mtime = modified_at

    def validate(self, token: str) -> AuthenticationToken:
        """Validate a raw authentication token and return an object.

        :param token: token value originating from the Authentication service
            either directly or from an upstream service
        :returns: a :py:class:`ValidatedAuthenticationToken` when one of the
            known public keys decodes the token, otherwise an
            :py:class:`InvalidAuthenticationToken` (empty token, expired
            token, or no key could decode it).
        """
        if not token:
            return InvalidAuthenticationToken()

        self._refresh_public_keys()

        for candidate_key in self._public_keys:
            try:
                payload = jwt.decode(token, candidate_key, algorithms=[self._algorithm_name])
            except jwt.ExpiredSignatureError:
                # An expired token is rejected outright; no other key is tried.
                return InvalidAuthenticationToken()
            except jwt.DecodeError:
                # This key did not work; fall through to the next version.
                continue
            else:
                return ValidatedAuthenticationToken(payload)

        return InvalidAuthenticationToken()
class AuthenticationToken:
    """Interface for information about the authenticated party.

    :py:class:`EdgeContext` provides high-level helpers for extracting
    data from authentication tokens; use those rather than reaching into this
    class directly.

    Every accessor on this base class raises ``NotImplementedError``;
    subclasses supply the real behavior.
    """

    @property
    def subject(self) -> Optional[str]:
        """Return the raw `subject` that is authenticated."""
        raise NotImplementedError

    @cached_property
    def user_roles(self) -> Set[str]:
        """Return the set of roles attached to the token."""
        raise NotImplementedError

    @property
    def oauth_client_id(self) -> Optional[str]:
        """Return the OAuth client ID attached to the token."""
        raise NotImplementedError

    @property
    def oauth_client_type(self) -> Optional[str]:
        """Return the OAuth client type attached to the token."""
        raise NotImplementedError

    @property
    def scopes(self) -> Set[str]:
        """Return the set of scopes attached to the token."""
        raise NotImplementedError

    @property
    def loid(self) -> Optional[str]:
        """Return the LoID attached to the token."""
        raise NotImplementedError

    @property
    def loid_created_ms(self) -> Optional[int]:
        """Return the LoID creation value (``created_ms``) attached to the token."""
        raise NotImplementedError
class ValidatedAuthenticationToken(AuthenticationToken):
    """Authentication token backed by a successfully decoded payload.

    Each accessor reads one claim from ``payload``; a claim that is absent
    yields ``None`` (or an empty set for the collection-valued claims).
    """

    def __init__(self, payload: Dict[str, Any]):
        """Store the decoded token payload.

        :param payload: mapping of claim name to value.
        """
        self.payload = payload

    @property
    def subject(self) -> Optional[str]:
        """Return the ``sub`` claim, or ``None`` if absent."""
        return self.payload.get("sub")

    @cached_property
    def user_roles(self) -> Set[str]:
        """Return the ``roles`` claim as a set.

        A missing *or null* claim yields an empty set, matching how
        :py:attr:`scopes` treats its claim (``set(None)`` would otherwise
        raise ``TypeError``).
        """
        return set(self.payload.get("roles") or [])

    @property
    def oauth_client_id(self) -> Optional[str]:
        """Return the ``client_id`` claim, or ``None`` if absent."""
        return self.payload.get("client_id")

    @property
    def oauth_client_type(self) -> Optional[str]:
        """Return the ``client_type`` claim, or ``None`` if absent."""
        return self.payload.get("client_type")

    @property
    def scopes(self) -> Set[str]:
        """Return the ``scopes`` claim as a set (empty if missing or null)."""
        return set(self.payload.get("scopes") or [])

    @property
    def loid(self) -> Optional[str]:
        """Return ``loid.id`` from the payload, or ``None`` if unavailable."""
        return (self.payload.get("loid") or {}).get("id")

    @property
    def loid_created_ms(self) -> Optional[int]:
        """Return ``loid.created_ms`` from the payload, or ``None`` if unavailable."""
        return (self.payload.get("loid") or {}).get("created_ms")
class InvalidAuthenticationToken(AuthenticationToken):
    """Stand-in token used when authentication is missing or invalid.

    Every accessor raises :py:class:`NoAuthenticationError`.
    """

    @property
    def subject(self) -> Optional[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @cached_property
    def user_roles(self) -> Set[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def oauth_client_id(self) -> Optional[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def oauth_client_type(self) -> Optional[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def scopes(self) -> Set[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def loid(self) -> Optional[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def loid_created_ms(self) -> Optional[int]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError
class Session(NamedTuple):
    """Session values carried by the EdgeContext."""

    id: str
    """ID of the session that this request belongs to."""
class Device(NamedTuple):
    """Device values carried by the EdgeContext."""

    id: str
    """Device ID of the client."""
class RequestId(NamedTuple):
    """Request-id values carried by the EdgeContext."""

    readable_id: str
    """Human readable Request ID of the request."""

    @property
    def id(self) -> Optional[str]:
        """Return the request ID; this is an alias for ``readable_id``."""
        return self.readable_id
class OriginService(NamedTuple):
    """Origin values carried by the EdgeContext."""

    name: str
    """Name of the service that created the edge context payload."""
class Geolocation(NamedTuple):
    """Geolocation values carried by the EdgeContext."""

    country_code: str
    """ISO-3166-1 alpha-2 code of the country the request came from."""
class Locale(NamedTuple):
    """Locale values carried by the EdgeContext."""

    locale_code: str
    """IETF language tag for the client's preferred locale."""
class User(NamedTuple):
    """User values drawn from the AuthenticationToken and the LoId cookie."""

    authentication_token: AuthenticationToken
    """Authentication provided for the request."""

    loid_: str
    """Internal LoID associated with the request, if applicable."""

    cookie_created_ms: int
    """When the authentication cookie was created, if applicable."""

    @property
    def id(self) -> Optional[str]:
        """Return the authenticated account_id for the current User.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token, it was invalid, or the subject is not an
            account (i.e. does not start with ``t2_``).

        """
        subject = self.authentication_token.subject
        if subject and subject.startswith("t2_"):
            return subject
        raise NoAuthenticationError

    @property
    def is_logged_in(self) -> bool:
        """Return whether the User has a valid, authenticated id."""
        try:
            account_id = self.id
        except NoAuthenticationError:
            return False
        return account_id is not None

    @property
    def roles(self) -> Set[str]:
        """Return the authenticated roles for the current User.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token or it was invalid

        """
        return self.authentication_token.user_roles

    def has_role(self, role: str) -> bool:
        """Return whether the authenticated user has the specified role.

        :param role: Role name to check; it is lower-cased before being
            looked up in :py:attr:`roles`.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token defined for the current context

        """
        return role.lower() in self.roles

    def event_fields(self) -> Dict[str, Any]:
        """Return fields to be added to events."""
        # ``loid`` yields "" when nothing is known; events carry None instead.
        return {
            "user_id": self.loid or None,
            "logged_in": self.is_logged_in,
            "cookie_created_timestamp": self.cookie_created_ms,
        }

    @property
    def loid(self) -> str:
        """Return the LoID associated with the request, if applicable.

        Sources are consulted in order: the logged-in account id, the LoID
        from the thrift payload, then the LoID in the authentication token.
        An empty string is returned when none of them is available.
        """
        # 1. A logged-in user's own id takes priority.
        try:
            account_id = self.id
        except NoAuthenticationError:
            account_id = None
        if account_id is not None:
            return account_id

        # 2. Non-empty LoID from the thrift payload.
        if self.loid_:
            return self.loid_

        # 3. LoID embedded in the authentication token; the token may be an
        #    InvalidAuthenticationToken, which raises on access.
        try:
            token_loid = self.authentication_token.loid
        except NoAuthenticationError:
            token_loid = None
        return token_loid or ""
class OAuthClient(NamedTuple):
    """OAuth2 client values drawn from the AuthenticationToken."""

    authentication_token: AuthenticationToken
    """Authentication token for this request."""

    @property
    def id(self) -> Optional[str]:
        """Return the authenticated id for the current client.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token defined for the current context

        """
        return self.authentication_token.oauth_client_id

    def is_type(self, *client_types: str) -> bool:
        """Return whether the authenticated client type is one of the given types.

        When checking the type of the current OauthClient, check that the
        type "is" one of the allowed types rather than checking that it
        "is not" a disallowed type.

        For example::

            if oauth_client.is_type("third_party"):
                ...

        not::

            if not oauth_client.is_type("first_party"):
                ...

        :param client_types: Sequence of client type names to check against;
            each name is lower-cased before comparison.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token defined for the current context

        """
        current_type = self.authentication_token.oauth_client_type
        if not current_type:
            return False
        # Candidates are lower-cased lazily, one at a time, as they are compared.
        return any(current_type == candidate.lower() for candidate in client_types)

    def event_fields(self) -> Dict[str, Any]:
        """Return fields to be added to events."""
        try:
            client_id = self.id
        except NoAuthenticationError:
            client_id = None
        return {"oauth_client_id": client_id}
class Service(NamedTuple):
    """Service values drawn from the AuthenticationToken."""

    authentication_token: AuthenticationToken
    """Authentication token for this request."""

    @property
    def name(self) -> str:
        """Return the authenticated service name.

        The name is the token subject with its ``service/`` prefix removed.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token, it was invalid, or the subject is not a
            service.

        """
        prefix = "service/"
        subject = self.authentication_token.subject
        if subject and subject.startswith(prefix):
            return subject[len(prefix) :]
        raise NoAuthenticationError
class EdgeContext:
    """Contextual information about the initial request to an edge service.

    Once the :py:class:`~reddit_edgecontext.EdgeContextFactory` is set up, an
    instance of this object will be available at ``request.edge_context``.

    The raw header is deserialized and the authentication token validated
    only when first needed; the results are cached on the instance.
    """

    # Thrift protocol factory used for the Edge-Request header payload.
    _HEADER_PROTOCOL_FACTORY = TBinaryProtocolAcceleratedFactory()

    def __init__(
        self, authn_token_validator: AuthenticationTokenValidator, header: Optional[bytes]
    ):
        """Store the validator and raw header without parsing anything yet.

        :param authn_token_validator: validator used to check the
            authentication token carried in the header.
        :param header: raw serialized header payload, or ``None``.
        """
        self._authn_token_validator = authn_token_validator
        self._header = header

    def event_fields(self) -> Dict[str, Any]:
        """Return fields to be added to events."""
        fields: Dict[str, Any] = {"session_id": self.session.id}

        # Device and request ids are only included when they are non-empty.
        device_id = self.device.id
        if device_id:
            fields["device_id"] = device_id

        edge_request_id = self.request_id.id
        if edge_request_id:
            fields["edge_request_id"] = edge_request_id

        fields.update(self.user.event_fields())
        fields.update(self.oauth_client.event_fields())
        return fields

    @cached_property
    def authentication_token(self) -> AuthenticationToken:
        """Validated form of the authentication token carried in the header."""
        raw_token = self._t_request.authentication_token
        return self._authn_token_validator.validate(raw_token)

    @cached_property
    def user(self) -> User:
        """:py:class:`~reddit_edgecontext.User` object for the current context."""
        t_loid = self._t_request.loid
        return User(
            authentication_token=self.authentication_token,
            loid_=t_loid.id,
            cookie_created_ms=t_loid.created_ms,
        )

    @cached_property
    def oauth_client(self) -> OAuthClient:
        """:py:class:`~reddit_edgecontext.OAuthClient` object for the current context."""
        return OAuthClient(self.authentication_token)

    @cached_property
    def device(self) -> Device:
        """:py:class:`~reddit_edgecontext.Device` object for the current context."""
        return Device(id=self._t_request.device.id)

    @cached_property
    def session(self) -> Session:
        """:py:class:`~reddit_edgecontext.Session` object for the current context."""
        return Session(id=self._t_request.session.id)

    @cached_property
    def service(self) -> Service:
        """:py:class:`~reddit_edgecontext.Service` object for the current context."""
        return Service(self.authentication_token)

    @cached_property
    def origin_service(self) -> OriginService:
        """:py:class:`~reddit_edgecontext.Origin` object for the current context."""
        return OriginService(self._t_request.origin_service.name)

    @cached_property
    def geolocation(self) -> Geolocation:
        """:py:class:`~reddit_edgecontext.Geolocation` object for the current context."""
        return Geolocation(country_code=self._t_request.geolocation.country_code)

    @cached_property
    def request_id(self) -> RequestId:
        """:py:class:`~reddit_edgecontext.RequestId` object for the current context."""
        return RequestId(readable_id=self._t_request.request_id.readable_id)

    @cached_property
    def locale(self) -> Locale:
        """:py:class:`~reddit_edgecontext.Locale` object for the current context."""
        return Locale(locale_code=self._t_request.locale.locale_code)

    @cached_property
    def _t_request(self) -> TRequest:
        """Thrift request deserialized from the raw header.

        Every sub-struct is pre-populated with an empty instance so attribute
        access on the result works even when the header is missing or cannot
        be deserialized.
        """
        t_request = TRequest(
            loid=TLoid(),
            session=TSession(),
            device=TDevice(),
            origin_service=TOriginService(),
            geolocation=TGeolocation(),
            request_id=TRequestId(),
            locale=TLocale(),
        )
        if not self._header:
            return t_request

        try:
            TSerialization.deserialize(t_request, self._header, self._HEADER_PROTOCOL_FACTORY)
        except Exception:
            # Best effort: a bad header is logged and whatever was populated
            # so far is still returned.
            logger.debug("Invalid Edge-Request header. %s", self._header)
        return t_request

    def attach_context(self, context: RequestContext) -> None:
        """Attach this to the provided :py:class:`~baseplate.RequestContext`.

        Sets ``context.edge_context`` to this object and
        ``context.raw_edge_context`` to the raw header.

        :param context: request context to attach this to

        """
        context.edge_context = self
        context.raw_edge_context = self._header
class EdgeContextFactory(BaseEdgeContextFactory):
    """Factory for creating :py:class:`EdgeContext` objects.

    Every application should set one of these up. Edge services that talk
    directly with clients should use :py:meth:`new` directly. For internal
    services, pass the object off to Baseplate's framework integration
    (Thrift/Pyramid) for automatic use.

    :param baseplate.lib.secrets.SecretsStore secrets: A configured secrets
        store.

    """

    def __init__(self, secrets: SecretsStore):
        # One validator is shared by every context this factory creates.
        self.authn_token_validator = AuthenticationTokenValidator(secrets)

    def new(
        self,
        authentication_token: Optional[str] = None,
        loid_id: Optional[str] = None,
        loid_created_ms: Optional[int] = None,
        session_id: Optional[str] = None,
        device_id: Optional[str] = None,
        origin_service_name: Optional[str] = None,
        country_code: Optional[str] = None,
        request_id: Optional[str] = None,
        locale_code: Optional[str] = None,
    ) -> EdgeContext:
        """Return a new EdgeContext object made from scratch.

        Services at the edge that communicate directly with clients should use
        this to pass on the information they get to downstream services. They
        can then use this information to check authentication, run experiments,
        etc.

        To use this, create and attach the context early in your request flow:

        .. code-block:: python

            auth_cookie = request.cookies["authentication"]
            token = request.authentication_service.authenticate_cookie(auth_cookie)
            loid = parse_loid(request.cookies["loid"])
            session = parse_session(request.cookies["session"])
            device_id = request.headers["x-device-id"]
            request_id = request.headers["x-request-id"]

            edge_context = self.edgecontext_factory.new(
                authentication_token=token,
                loid_id=loid.id,
                loid_created_ms=loid.created,
                session_id=session.id,
                device_id=device_id,
                request_id=request_id,
            )
            edge_context.attach_context(request)

        :param authentication_token: A raw authentication token as returned by
            the authentication service.
        :param loid_id: ID for the current LoID in fullname format.
        :param loid_created_ms: Epoch milliseconds when the current LoID cookie
            was created.
        :param session_id: ID for the current session cookie.
        :param device_id: ID for the device where the request originated from.
        :param origin_service_name: Name for the "origin" service handling the
            request from the client.
        :param country_code: two-character ISO 3166-1 country code where the
            request originated from.
        :param request_id: The human readable form of the unique id assigned to
            the underlying request that this EdgeContext represents.
        :param locale_code: IETF language tag representing the preferred locale
            for the client.
        :raises ValueError: if ``loid_id`` does not start with ``t2_``, or if
            ``country_code`` / ``locale_code`` do not match their expected
            formats.

        """
        if loid_id is not None and not loid_id.startswith("t2_"):
            raise ValueError(
                f"loid_id <{loid_id}> is not in a valid format, it should be in the "
                "fullname format with the '0' padding removed: 't2_loid_id'"
            )

        if country_code is not None and not COUNTRY_CODE_RE.match(country_code):
            raise ValueError(
                f"country_code <{country_code}> is not in a valid format, it should be in "
                "ISO 3166-1 alpha-2 format: 'US'"
            )

        if locale_code is not None and not LOCALE_CODE_RE.match(locale_code):
            # Each fragment ends with a space so the implicitly concatenated
            # literals form readable prose.
            raise ValueError(
                f"locale_code <{locale_code}> is not in a valid format, it should be in "
                "IETF language code format – an ISO 639-1 primary language subtag and an "
                "optional ISO 3166-1 alpha-2 region subtag separated by an underscore. "
                "e.g. en_US"
            )

        t_request = TRequest(
            loid=TLoid(id=loid_id, created_ms=loid_created_ms),
            session=TSession(id=session_id),
            authentication_token=authentication_token,
            device=TDevice(id=device_id),
            origin_service=TOriginService(name=origin_service_name),
            geolocation=TGeolocation(country_code=country_code),
            request_id=TRequestId(readable_id=request_id),
            locale=TLocale(locale_code=locale_code),
        )
        header = TSerialization.serialize(t_request, EdgeContext._HEADER_PROTOCOL_FACTORY)

        context = EdgeContext(self.authn_token_validator, header)
        # Set the _t_request property so we can skip the deserialization step
        # since we already have the thrift object.
        context._t_request = t_request
        return context

    def from_upstream(self, edge_header: Optional[bytes]) -> EdgeContext:
        """Create and return an EdgeContext from an upstream header.

        This is generally used internally to Baseplate by framework
        integrations that automatically pick up context from inbound requests.

        :param edge_header: Raw payload of Edge-Request header from upstream
            service.

        """
        return EdgeContext(self.authn_token_validator, edge_header)