Coverage for an_website/utils/token.py: 94.595%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-16 19:56 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""A module providing special auth tokens.""" 

15 

16from __future__ import annotations 

17 

18import hmac 

19import math 

20from base64 import b64decode, b64encode 

21from datetime import datetime 

22from hashlib import blake2b 

23from typing import ClassVar, Literal, NamedTuple, TypeAlias, TypeGuard, get_args 

24 

25from .utils import Permission 

26 

27TokenVersion: TypeAlias = Literal["0"] 

28SUPPORTED_TOKEN_VERSIONS: tuple[TokenVersion, ...] = get_args(TokenVersion) 

29 

30 

class ParseResult(NamedTuple):
    """A decoded auth token together with the data embedded in it."""

    # The complete token string, including its version prefix.
    token: str
    # The permissions granted by the token.
    permissions: Permission
    # End of the validity window (the token's start time plus its duration).
    valid_until: datetime
    # The salt bytes stored inside the token.
    salt: bytes

38 

39 

class InvalidTokenError(Exception):
    """Raised when a token is malformed, forged, or outside its validity window."""

42 

43 

class InvalidTokenVersionError(InvalidTokenError):
    """Raised when a token's version prefix is not a supported version.

    The module-level ``SUPPORTED_TOKEN_VERSIONS`` tuple is mirrored as a
    class attribute of the same name.
    """

    SUPPORTED_TOKEN_VERSIONS: ClassVar = SUPPORTED_TOKEN_VERSIONS

48 

49 

def is_supported_version(version: str) -> TypeGuard[TokenVersion]:
    """Tell whether ``version`` is a token version this module understands.

    Being a ``TypeGuard``, a ``True`` result narrows ``version`` to
    ``TokenVersion`` for static type checkers.
    """
    supported: tuple[TokenVersion, ...] = SUPPORTED_TOKEN_VERSIONS
    return version in supported

53 

54 

def _split_token(token: str) -> tuple[TokenVersion, str]:
    """Separate the one-character version prefix from the rest of a token.

    Raises:
        InvalidTokenError: If ``token`` is empty.
        InvalidTokenVersionError: If the prefix is not a supported version.
    """
    if not token:
        raise InvalidTokenError()

    prefix, body = token[:1], token[1:]
    if not is_supported_version(prefix):
        raise InvalidTokenVersionError()
    return prefix, body

65 

66 

def parse_token(
    token: str,
    *,
    secret: bytes | str,
    verify_time: bool = True,
) -> ParseResult:
    """Parse and verify an auth token.

    Args:
        token: The complete token, starting with its one-character version.
        secret: The HMAC key; a ``str`` is encoded as UTF-8 first.
        verify_time: Whether to reject tokens whose validity window does not
            contain the current time.

    Returns:
        The data decoded from the token.

    Raises:
        InvalidTokenVersionError: If the version prefix is not supported.
        InvalidTokenError: If the token is empty, cannot be decoded, fails
            HMAC verification or (with ``verify_time``) is not currently valid.
    """
    secret = secret.encode("UTF-8") if isinstance(secret, str) else secret
    version, token_body = _split_token(token)
    try:
        if version == "0":
            return _parse_token_v0(token_body, secret, verify_time=verify_time)
    except InvalidTokenError:
        raise
    except Exception as exc:
        # Any other decoding failure (e.g. invalid base64) is reported as an
        # invalid token, keeping the original error as the cause.
        raise InvalidTokenError from exc
    # Only reachable if a version is added to TokenVersion without a parser
    # here; fail loudly instead of implicitly returning None.
    raise InvalidTokenVersionError()

83 

84 

def create_token(  # pylint: disable=too-many-arguments
    permissions: Permission,
    *,
    secret: bytes | str,
    duration: int,
    start: None | datetime = None,
    salt: None | bytes | str = None,
    version: TokenVersion = SUPPORTED_TOKEN_VERSIONS[-1],
) -> ParseResult:
    """Create an auth token.

    Args:
        permissions: The permissions granted by the token.
        secret: The HMAC key; a ``str`` is encoded as UTF-8 first.
        duration: Seconds the token stays valid after ``start``.
        start: When the token becomes valid; defaults to ``datetime.now()``.
        salt: Optional salt; a ``str`` is encoded as UTF-8 first. When
            omitted or empty, the version-specific default salt is used.
        version: The token format version to create.

    Returns:
        The new token, parsed back without checking the current time.

    Raises:
        InvalidTokenVersionError: If ``version`` is not supported.
    """
    secret = secret.encode("UTF-8") if isinstance(secret, str) else secret
    start = datetime.now() if start is None else start
    salt = salt.encode("UTF-8") if isinstance(salt, str) else (salt or b"")

    if version == "0":
        token_body = _create_token_body_v0(
            permissions, secret, duration, start, salt
        )
    else:
        # Without this, an unknown version fell through with the token body
        # unbound and crashed with UnboundLocalError.
        raise InvalidTokenVersionError()

    return parse_token(version + token_body, secret=secret, verify_time=False)

105 

106 

def int_to_bytes(number: int, length: int, signed: bool = False) -> bytes:
    """Encode ``number`` as exactly ``length`` big-endian bytes.

    ``int.to_bytes`` raises ``OverflowError`` when the value does not fit,
    e.g. a negative number while ``signed`` is false.
    """
    return number.to_bytes(length, byteorder="big", signed=signed)

110 

111 

def bytes_to_int(bytes_: bytes, signed: bool = False) -> int:
    """Convert big-endian bytes to an int.

    The inverse of ``int_to_bytes``; an empty byte string decodes to 0.
    """
    return int.from_bytes(bytes_, byteorder="big", signed=signed)

115 

116 

def _parse_token_v0(
    token_body: str, secret: bytes, *, verify_time: bool = True
) -> ParseResult:
    """Decode and verify the body (version prefix removed) of a v0 token.

    The decoded body is: permissions (possibly preceded by zero padding),
    a 6-byte salt, a 5-byte duration, a 5-byte start timestamp, and a
    48-byte HMAC-SHA3-384 over all preceding bytes.

    Raises:
        InvalidTokenError: If the HMAC does not match, or if ``verify_time``
            is set and the current time is outside ``[start, start + duration]``.
    """
    raw = b64decode(token_body)
    payload, signature = raw[:-48], raw[-48:]
    expected = hmac.digest(secret, payload, "SHA3-384")
    if not hmac.compare_digest(expected, signature):
        raise InvalidTokenError()

    # Fields are taken relative to the end of the payload, so any zero
    # padding in front of the permissions does not change their value.
    start = bytes_to_int(payload[-5:])
    duration = bytes_to_int(payload[-10:-5])
    salt = payload[-16:-10]
    permissions = bytes_to_int(payload[:-16])

    if verify_time:
        now = int(datetime.now().timestamp())
        if not (start <= now <= start + duration):
            raise InvalidTokenError()

    return ParseResult(
        "0" + token_body,
        Permission(permissions),
        datetime.fromtimestamp(start + duration),
        salt,
    )

139 

140 

def _create_token_body_v0(
    permissions: Permission,
    secret: bytes,
    duration: int,
    start: datetime,
    salt: bytes,
) -> str:
    """Build the base64 body (without version prefix) of a version 0 token.

    Before encoding the body is: optional zero padding, permissions, a 6-byte
    salt, a 5-byte duration, a 5-byte start timestamp and a 48-byte
    HMAC-SHA3-384 over all preceding bytes. Integers are big-endian, unsigned.
    """
    digest_size = 384 // 8

    if salt:
        # Left-pad short salts with b"U", then cut everything to six bytes.
        salt = (b"U" * (6 - len(salt)) + salt)[:6]
    else:
        # No salt given: derive a deterministic one from start and duration.
        seed = int_to_bytes(int(start.timestamp() - duration), 5)
        salt = blake2b(seed, digest_size=6).digest()

    payload = b"".join(
        (
            int_to_bytes(permissions, (len(Permission) + 7) // 8),
            salt,
            int_to_bytes(duration, 5),
            int_to_bytes(int(start.timestamp()), 5),
        )
    )

    # Prepend zero bytes so payload + digest is a multiple of three bytes,
    # which keeps the base64 output free of "=" padding.
    payload = bytes(-(len(payload) + digest_size) % 3) + payload

    signature = hmac.digest(secret, payload, "SHA3-384")
    return b64encode(payload + signature).decode("UTF-8")