communex.module.routers.module_routers
import json
import re
from abc import abstractmethod
from datetime import datetime, timezone
from functools import partial
from typing import Any, Protocol, Sequence

import starlette.datastructures
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from keylimiter import TokenBucketLimiter
from substrateinterface import Keypair  # type: ignore

from communex._common import get_node_url
from communex.module import _signer as signer
from communex.module._rate_limiters._stake_limiter import StakeLimiter
from communex.module._rate_limiters.limiters import (IpLimiterParams,
                                                     StakeLimiterParams)
from communex.module._util import (json_error, log, log_reffusal, make_client,
                                   try_ss58_decode)
from communex.types import Ss58Address
from communex.util.memo import TTLDict

# One or more hexadecimal digits and nothing else (no "0x" prefix).
# Note: `$` also matches just before a single trailing newline.
HEX_PATTERN = re.compile(r"^[0-9a-fA-F]+$")


def is_hex_string(string: str):
    """Tell whether *string* is made only of hexadecimal digits.

    A leading ``0x`` is rejected, since ``x`` is not a hex digit.
    """
    return HEX_PATTERN.match(string) is not None


def parse_hex(hex_str: str) -> bytes:
    """Decode *hex_str*, with or without a leading ``0x``, into bytes.

    Malformed input propagates the ValueError raised by ``bytes.fromhex``.
    """
    digits = hex_str[2:] if hex_str.startswith("0x") else hex_str
    return bytes.fromhex(digits)


class AbstractVerifier(Protocol):
    """Structural interface of the request verifiers run by build_route_class."""

    @abstractmethod
    async def verify(self, request: Request) -> JSONResponse | None:
        """Inspect *request* and decide whether it may proceed.

        Return ``None`` to accept it, or a ``JSONResponse`` to send back to
        the caller instead. Implementations must not mutate the request.
        """
        ...
44 45 46class StakeLimiterVerifier(AbstractVerifier): 47 def __init__( 48 self, 49 subnets_whitelist: list[int] | None, 50 params_: StakeLimiterParams | None, 51 ): 52 self.subnets_whitelist = subnets_whitelist 53 self.params_ = params_ 54 if not self.params_: 55 params = StakeLimiterParams() 56 else: 57 params = self.params_ 58 self.limiter = StakeLimiter( 59 self.subnets_whitelist, 60 epoch=params.epoch, 61 max_cache_age=params.cache_age, 62 get_refill_rate=params.get_refill_per_epoch, 63 ) 64 65 async def verify(self, request: Request): 66 67 if request.client is None: 68 response = JSONResponse( 69 status_code=401, 70 content={ 71 "error": "Address should be present in request" 72 } 73 ) 74 return response 75 76 key = request.headers.get('x-key') 77 if not key: 78 response = JSONResponse( 79 status_code=401, 80 content={"error": "Valid X-Key not provided on headers"} 81 ) 82 return response 83 84 is_allowed = await self.limiter.allow(key) 85 86 if not is_allowed: 87 response = JSONResponse( 88 status_code=429, 89 headers={"X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds"}, 90 content={"error": "Rate limit exceeded"} 91 ) 92 return response 93 return None 94 95 96class ListVerifier(AbstractVerifier): 97 def __init__( 98 self, 99 blacklist: list[Ss58Address] | None, 100 whitelist: list[Ss58Address] | None, 101 ip_blacklist: list[str] | None, 102 ): 103 self.blacklist = blacklist 104 self.whitelist = whitelist 105 self.ip_blacklist = ip_blacklist 106 107 async def verify(self, request: Request) -> JSONResponse | None: 108 key = request.headers.get("x-key") 109 if not key: 110 reason = "Missing header: X-Key" 111 log(f"INFO: refusing module request because: {reason}") 112 return json_error(400, "Missing header: X-Key") 113 114 ss58 = try_ss58_decode(key) 115 if ss58 is None: 116 reason = "Caller key could not be decoded into a ss58address" 117 log_reffusal(key, reason) 118 return json_error(400, reason) 119 if request.client is None: 120 
return json_error(400, "Address should be present in request") 121 if self.blacklist and ss58 in self.blacklist: 122 return json_error(403, "You are blacklisted") 123 if self.ip_blacklist and request.client.host in self.ip_blacklist: 124 return json_error(403, "Your IP is blacklisted") 125 if self.whitelist and ss58 not in self.whitelist: 126 return json_error(403, "You are not whitelisted") 127 return None 128 129 130class IpLimiterVerifier(AbstractVerifier): 131 def __init__( 132 self, 133 params: IpLimiterParams | None, 134 ): 135 ''' 136 :param limiter: KeyLimiter instance OR None 137 138 If limiter is None, then a default TokenBucketLimiter is used with the following config: 139 bucket_size=200, refill_rate=15 140 ''' 141 142 # fallback to default limiter 143 if not params: 144 params = IpLimiterParams() 145 self._limiter = TokenBucketLimiter( 146 bucket_size=params.bucket_size, 147 refill_rate=params.refill_rate 148 ) 149 150 async def verify(self, request: Request): 151 assert request.client is not None, "request is invalid" 152 assert request.client.host, "request is invalid." 
153 154 ip = request.client.host 155 156 is_allowed = self._limiter.allow(ip) 157 158 if not is_allowed: 159 response = JSONResponse( 160 status_code=429, 161 headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))}, 162 content={"error": "Rate limit exceeded"} 163 ) 164 return response 165 return None 166 167 168class InputHandlerVerifier(AbstractVerifier): 169 def __init__( 170 self, 171 subnets_whitelist: list[int] | None, 172 module_key: Ss58Address, 173 request_staleness: int, 174 blockchain_cache: TTLDict[str, list[Ss58Address]], 175 host_key: Keypair, 176 use_testnet: bool, 177 ): 178 self.subnets_whitelist = subnets_whitelist 179 self.module_key = module_key 180 self.request_staleness = request_staleness 181 self.blockchain_cache = blockchain_cache 182 self.host_key = host_key 183 self.use_testnet = use_testnet 184 185 async def verify(self, request: Request): 186 body = await request.body() 187 188 # TODO: we'll replace this by a Result ADT :) 189 match self._check_inputs(request, body, self.module_key): 190 case (False, error): 191 return error 192 case (True, _): 193 pass 194 195 body_dict: dict[str, dict[str, Any]] = json.loads(body) 196 timestamp = body_dict['params'].get("timestamp", None) 197 legacy_timestamp = request.headers.get("X-Timestamp", None) 198 try: 199 timestamp_to_use = timestamp if not legacy_timestamp else legacy_timestamp 200 request_time = datetime.fromisoformat(timestamp_to_use) 201 except Exception: 202 return JSONResponse(status_code=400, content={"error": "Invalid ISO timestamp given"}) 203 if (datetime.now(timezone.utc) - request_time).total_seconds() > self.request_staleness: 204 return JSONResponse(status_code=400, content={"error": "Request is too stale"}) 205 return None 206 207 def _check_inputs( 208 self, request: Request, 209 body: bytes, 210 module_key: Ss58Address 211 ): 212 required_headers = ["x-signature", "x-key", "x-crypto"] 213 optional_headers = ["x-timestamp"] 214 215 # TODO: we'll replace this by a 
Result ADT :) 216 match self._get_headers_dict(request.headers, required_headers, optional_headers): 217 case (False, error): 218 return (False, error) 219 case (True, headers_dict): 220 pass 221 222 # TODO: we'll replace this by a Result ADT :) 223 match self._check_signature(headers_dict, body, module_key): 224 case (False, error): 225 return (False, error) 226 case (True, _): 227 pass 228 229 # TODO: we'll replace this by a Result ADT :) 230 match self._check_key_registered( 231 self.subnets_whitelist, 232 headers_dict, 233 self.blockchain_cache, 234 self.host_key, 235 self.use_testnet, 236 ): 237 case (False, error): 238 return (False, error) 239 case (True, _): 240 pass 241 242 return (True, None) 243 244 def _get_headers_dict( 245 self, 246 headers: starlette.datastructures.Headers, 247 required: list[str], 248 optional: list[str], 249 ): 250 headers_dict: dict[str, str] = {} 251 for required_header in required: 252 value = headers.get(required_header) 253 if not value: 254 code = 400 255 return False, json_error(code, f"Missing header: {required_header}") 256 headers_dict[required_header] = value 257 for optional_header in optional: 258 value = headers.get(optional_header) 259 if value: 260 headers_dict[optional_header] = value 261 262 return True, headers_dict 263 264 def _check_signature( 265 self, 266 headers_dict: dict[str, str], 267 body: bytes, 268 module_key: Ss58Address 269 ): 270 key = headers_dict["x-key"] 271 signature = headers_dict["x-signature"] 272 crypto = int(headers_dict["x-crypto"]) # TODO: better handling of this 273 274 if not is_hex_string(key): 275 reason = "X-Key should be a hex value" 276 log_reffusal(key, reason) 277 return (False, json_error(400, reason)) 278 try: 279 signature = parse_hex(signature) 280 except Exception: 281 reason = "Signature sent is not a valid hex value" 282 log_reffusal(key, reason) 283 return False, json_error(400, reason) 284 try: 285 key = parse_hex(key) 286 except Exception: 287 reason = "Key sent is not 
a valid hex value" 288 log_reffusal(key, reason) 289 return False, json_error(400, reason) 290 # decodes the key for better logging 291 key_ss58 = try_ss58_decode(key) 292 if key_ss58 is None: 293 reason = "Caller key could not be decoded into a ss58address" 294 log_reffusal(key.decode(), reason) 295 return (False, json_error(400, reason)) 296 297 timestamp = headers_dict.get("x-timestamp") 298 legacy_verified = False 299 if timestamp: 300 # tries to do a legacy verification 301 json_body = json.loads(body) 302 json_body["timestamp"] = timestamp 303 stamped_body = json.dumps(json_body).encode() 304 legacy_verified = signer.verify(key, crypto, stamped_body, signature) 305 306 verified = signer.verify(key, crypto, body, signature) 307 if not verified and not legacy_verified: 308 reason = "Signature doesn't match" 309 log_reffusal(key_ss58, reason) 310 return (False, json_error(401, "Signatures doesn't match")) 311 312 body_dict: dict[str, dict[str, Any]] = json.loads(body) 313 target_key = body_dict['params'].get("target_key", None) 314 if not target_key or target_key != module_key: 315 reason = "Wrong target_key in body" 316 log_reffusal(key_ss58, reason) 317 return (False, json_error(401, "Wrong target_key in body")) 318 319 return (True, None) 320 321 def _check_key_registered( 322 self, 323 subnets_whitelist: list[int] | None, 324 headers_dict: dict[str, str], 325 blockchain_cache: TTLDict[str, list[Ss58Address]], 326 host_key: Keypair, 327 use_testnet: bool, 328 ): 329 key = headers_dict["x-key"] 330 if not is_hex_string(key): 331 return (False, json_error(400, "X-Key should be a hex value")) 332 key = parse_hex(key) 333 334 # TODO: checking for key being registered should be smarter 335 # e.g. query and store all registered modules periodically. 
336 337 ss58 = try_ss58_decode(key) 338 if ss58 is None: 339 reason = "Caller key could not be decoded into a ss58address" 340 log_reffusal(key.decode(), reason) 341 return (False, json_error(400, reason)) 342 343 # If subnets whitelist is specified, checks if key is registered in one 344 # of the given subnets 345 346 allowed_subnets: dict[int, bool] = {} 347 caller_subnets: list[int] = [] 348 if subnets_whitelist is not None: 349 def query_keys(subnet: int): 350 try: 351 node_url = get_node_url(None, use_testnet=use_testnet) 352 client = make_client(node_url) # TODO: get client from outer context 353 return [*client.query_map_key(subnet).values()] 354 except Exception: 355 log( 356 "WARNING: Could not connect to a blockchain node" 357 ) 358 return_list: list[Ss58Address] = [] 359 return return_list 360 361 # TODO: client pool for entire module server 362 363 got_keys = False 364 no_keys_reason = ( 365 "Miner could not connect to a blockchain node " 366 "or there is no key registered on the subnet(s) {} " 367 ) 368 for subnet in subnets_whitelist: 369 get_keys_on_subnet = partial(query_keys, subnet) 370 cache_key = f"keys_on_subnet_{subnet}" 371 keys_on_subnet = blockchain_cache.get_or_insert_lazy( 372 cache_key, get_keys_on_subnet 373 ) 374 if len(keys_on_subnet) == 0: 375 reason = no_keys_reason.format(subnet) 376 log(f"WARNING: {reason}") 377 else: 378 got_keys = True 379 if host_key.ss58_address not in keys_on_subnet: 380 log( 381 f"WARNING: This miner is deregistered on subnet {subnet}" 382 ) 383 else: 384 allowed_subnets[subnet] = True 385 if ss58 in keys_on_subnet: 386 caller_subnets.append(subnet) 387 if not got_keys: 388 return False, json_error(503, no_keys_reason.format(subnets_whitelist)) 389 if not allowed_subnets: 390 log("WARNING: Miner is not registered on any subnet") 391 return False, json_error(403, "Miner is not registered on any subnet") 392 393 # searches for a common subnet between caller and miner 394 # TODO: use sets 395 allowed_subnets = 
{ 396 subnet: allowed for subnet, allowed in allowed_subnets.items() if ( 397 subnet in caller_subnets 398 ) 399 } 400 if not allowed_subnets: 401 reason = "Caller key is not registered in any subnet that the miner is" 402 log_reffusal(ss58, reason) 403 return False, json_error( 404 403, reason 405 ) 406 else: 407 # accepts everything 408 pass 409 410 return (True, None) 411 412 413def build_route_class( 414 verifiers: Sequence[AbstractVerifier] 415) -> type[APIRoute]: 416 417 class CheckListsRoute(APIRoute): 418 def get_route_handler(self): 419 original_route_handler = super().get_route_handler() 420 421 async def custom_route_handler(request: Request) -> Response | JSONResponse: 422 if not request.url.path.startswith('/method'): 423 unhandled_response: Response = await original_route_handler(request) 424 return unhandled_response 425 for verifier in verifiers: 426 response = await verifier.verify(request) 427 if response is not None: 428 return response 429 430 original_response: Response = await original_route_handler(request) 431 return original_response 432 433 return custom_route_handler 434 435 return CheckListsRoute
class AbstractVerifier(Protocol):
    """Structural interface of the request verifiers run by build_route_class."""

    @abstractmethod
    async def verify(self, request: Request) -> JSONResponse | None:
        """Inspect *request* and decide whether it may proceed.

        Return ``None`` to accept it, or a ``JSONResponse`` to send back to
        the caller instead. Implementations must not mutate the request.
        """
        ...
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
class StakeLimiterVerifier(AbstractVerifier):
    """Rate-limits callers per ``X-Key`` header value using a StakeLimiter."""

    def __init__(
        self,
        subnets_whitelist: list[int] | None,
        params_: StakeLimiterParams | None,
    ):
        """
        :param subnets_whitelist: forwarded unchanged to StakeLimiter.
        :param params_: limiter configuration; when falsy, a default
            ``StakeLimiterParams()`` is used instead.
        """
        self.subnets_whitelist = subnets_whitelist
        self.params_ = params_
        settings = params_ or StakeLimiterParams()
        self.limiter = StakeLimiter(
            self.subnets_whitelist,
            epoch=settings.epoch,
            max_cache_age=settings.cache_age,
            get_refill_rate=settings.get_refill_per_epoch,
        )

    async def verify(self, request: Request):
        """Refuse with 401 when the client address or ``X-Key`` header is
        missing, and with 429 when the key has no budget left."""
        if request.client is None:
            return JSONResponse(
                status_code=401,
                content={"error": "Address should be present in request"},
            )

        key = request.headers.get('x-key')
        if not key:
            return JSONResponse(
                status_code=401,
                content={"error": "Valid X-Key not provided on headers"},
            )

        if await self.limiter.allow(key):
            return None

        retry_after = await self.limiter.retry_after(key)
        return JSONResponse(
            status_code=429,
            headers={"X-RateLimit-TryAfter": f"{retry_after!s} seconds"},
            content={"error": "Rate limit exceeded"},
        )
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
def __init__(
    self,
    subnets_whitelist: list[int] | None,
    params_: StakeLimiterParams | None,
):
    """
    :param subnets_whitelist: forwarded unchanged to StakeLimiter.
    :param params_: limiter configuration; when falsy, a default
        ``StakeLimiterParams()`` is used instead.
    """
    self.subnets_whitelist = subnets_whitelist
    self.params_ = params_
    settings = params_ or StakeLimiterParams()
    self.limiter = StakeLimiter(
        self.subnets_whitelist,
        epoch=settings.epoch,
        max_cache_age=settings.cache_age,
        get_refill_rate=settings.get_refill_per_epoch,
    )
async def verify(self, request: Request):
    """Refuse with 401 when the client address or ``X-Key`` header is
    missing, and with 429 when the key has no budget left."""
    if request.client is None:
        return JSONResponse(
            status_code=401,
            content={"error": "Address should be present in request"},
        )

    key = request.headers.get('x-key')
    if not key:
        return JSONResponse(
            status_code=401,
            content={"error": "Valid X-Key not provided on headers"},
        )

    if await self.limiter.allow(key):
        return None

    retry_after = await self.limiter.retry_after(key)
    return JSONResponse(
        status_code=429,
        headers={"X-RateLimit-TryAfter": f"{retry_after!s} seconds"},
        content={"error": "Rate limit exceeded"},
    )
Please don't mutate the request.
class ListVerifier(AbstractVerifier):
    """Accepts or refuses callers based on static key and IP lists."""

    def __init__(
        self,
        blacklist: list[Ss58Address] | None,
        whitelist: list[Ss58Address] | None,
        ip_blacklist: list[str] | None,
    ):
        """
        :param blacklist: ss58 addresses that are always refused.
        :param whitelist: when non-empty, the only ss58 addresses accepted.
        :param ip_blacklist: client hosts that are always refused.

        ``None`` or an empty list disables the corresponding check.
        """
        self.blacklist = blacklist
        self.whitelist = whitelist
        self.ip_blacklist = ip_blacklist

    async def verify(self, request: Request) -> JSONResponse | None:
        """Refuse with 400 on a missing/undecodable ``X-Key`` or a missing
        client address, and with 403 when a list rule rejects the caller."""
        caller_key = request.headers.get("x-key")
        if not caller_key:
            reason = "Missing header: X-Key"
            log(f"INFO: refusing module request because: {reason}")
            return json_error(400, reason)

        # NOTE(review): unlike InputHandlerVerifier, the raw header text is
        # decoded here without parse_hex — confirm try_ss58_decode accepts it.
        caller_ss58 = try_ss58_decode(caller_key)
        if caller_ss58 is None:
            reason = "Caller key could not be decoded into a ss58address"
            log_reffusal(caller_key, reason)
            return json_error(400, reason)

        client = request.client
        if client is None:
            return json_error(400, "Address should be present in request")

        # Rules are evaluated in order; the first one that matches wins.
        if self.blacklist and caller_ss58 in self.blacklist:
            return json_error(403, "You are blacklisted")
        if self.ip_blacklist and client.host in self.ip_blacklist:
            return json_error(403, "Your IP is blacklisted")
        if self.whitelist and caller_ss58 not in self.whitelist:
            return json_error(403, "You are not whitelisted")
        return None
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
async def verify(self, request: Request) -> JSONResponse | None:
    """Refuse with 400 on a missing/undecodable ``X-Key`` or a missing
    client address, and with 403 when a list rule rejects the caller."""
    caller_key = request.headers.get("x-key")
    if not caller_key:
        reason = "Missing header: X-Key"
        log(f"INFO: refusing module request because: {reason}")
        return json_error(400, reason)

    # NOTE(review): the raw header text is decoded without parse_hex —
    # confirm try_ss58_decode accepts a hex string.
    caller_ss58 = try_ss58_decode(caller_key)
    if caller_ss58 is None:
        reason = "Caller key could not be decoded into a ss58address"
        log_reffusal(caller_key, reason)
        return json_error(400, reason)

    client = request.client
    if client is None:
        return json_error(400, "Address should be present in request")

    # Rules are evaluated in order; the first one that matches wins.
    if self.blacklist and caller_ss58 in self.blacklist:
        return json_error(403, "You are blacklisted")
    if self.ip_blacklist and client.host in self.ip_blacklist:
        return json_error(403, "Your IP is blacklisted")
    if self.whitelist and caller_ss58 not in self.whitelist:
        return json_error(403, "You are not whitelisted")
    return None
Please don't mutate the request.
class IpLimiterVerifier(AbstractVerifier):
    """Rate-limits callers by client IP address with a token bucket."""

    def __init__(
        self,
        params: IpLimiterParams | None,
    ):
        '''
        :param params: IpLimiterParams instance OR None

        If params is None, the TokenBucketLimiter is configured from the
        default ``IpLimiterParams()`` (documented as bucket_size=200,
        refill_rate=15).
        '''

        # fallback to the default limiter configuration
        if not params:
            params = IpLimiterParams()
        self._limiter = TokenBucketLimiter(
            bucket_size=params.bucket_size,
            refill_rate=params.refill_rate
        )

    async def verify(self, request: Request):
        """Return ``None`` while the client IP has tokens left, 429 once it
        is exhausted, or 400 when the request carries no client address."""
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and a failing one surfaces as an unhandled 500.
        client = request.client
        if client is None or not client.host:
            return json_error(400, "Address should be present in request")

        ip = client.host
        if self._limiter.allow(ip):
            return None
        return JSONResponse(
            status_code=429,
            headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
            content={"error": "Rate limit exceeded"},
        )
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
def __init__(
    self,
    params: IpLimiterParams | None,
):
    '''
    :param params: IpLimiterParams instance OR None

    If params is None, the TokenBucketLimiter is configured from the
    default ``IpLimiterParams()`` (documented as bucket_size=200,
    refill_rate=15).
    '''

    # fallback to the default limiter configuration
    if not params:
        params = IpLimiterParams()
    self._limiter = TokenBucketLimiter(
        bucket_size=params.bucket_size,
        refill_rate=params.refill_rate
    )
Parameters
- params: IpLimiterParams instance OR None
If params is None, the TokenBucketLimiter is configured from the default IpLimiterParams() (bucket_size=200, refill_rate=15).
async def verify(self, request: Request):
    """Return ``None`` while the client IP has tokens left, 429 once it
    is exhausted, or 400 when the request carries no client address."""
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and a failing one surfaces as an unhandled 500.
    client = request.client
    if client is None or not client.host:
        return json_error(400, "Address should be present in request")

    ip = client.host
    if self._limiter.allow(ip):
        return None
    return JSONResponse(
        status_code=429,
        headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
        content={"error": "Rate limit exceeded"},
    )
Please don't mutate the request.
169class InputHandlerVerifier(AbstractVerifier): 170 def __init__( 171 self, 172 subnets_whitelist: list[int] | None, 173 module_key: Ss58Address, 174 request_staleness: int, 175 blockchain_cache: TTLDict[str, list[Ss58Address]], 176 host_key: Keypair, 177 use_testnet: bool, 178 ): 179 self.subnets_whitelist = subnets_whitelist 180 self.module_key = module_key 181 self.request_staleness = request_staleness 182 self.blockchain_cache = blockchain_cache 183 self.host_key = host_key 184 self.use_testnet = use_testnet 185 186 async def verify(self, request: Request): 187 body = await request.body() 188 189 # TODO: we'll replace this by a Result ADT :) 190 match self._check_inputs(request, body, self.module_key): 191 case (False, error): 192 return error 193 case (True, _): 194 pass 195 196 body_dict: dict[str, dict[str, Any]] = json.loads(body) 197 timestamp = body_dict['params'].get("timestamp", None) 198 legacy_timestamp = request.headers.get("X-Timestamp", None) 199 try: 200 timestamp_to_use = timestamp if not legacy_timestamp else legacy_timestamp 201 request_time = datetime.fromisoformat(timestamp_to_use) 202 except Exception: 203 return JSONResponse(status_code=400, content={"error": "Invalid ISO timestamp given"}) 204 if (datetime.now(timezone.utc) - request_time).total_seconds() > self.request_staleness: 205 return JSONResponse(status_code=400, content={"error": "Request is too stale"}) 206 return None 207 208 def _check_inputs( 209 self, request: Request, 210 body: bytes, 211 module_key: Ss58Address 212 ): 213 required_headers = ["x-signature", "x-key", "x-crypto"] 214 optional_headers = ["x-timestamp"] 215 216 # TODO: we'll replace this by a Result ADT :) 217 match self._get_headers_dict(request.headers, required_headers, optional_headers): 218 case (False, error): 219 return (False, error) 220 case (True, headers_dict): 221 pass 222 223 # TODO: we'll replace this by a Result ADT :) 224 match self._check_signature(headers_dict, body, module_key): 225 case 
(False, error): 226 return (False, error) 227 case (True, _): 228 pass 229 230 # TODO: we'll replace this by a Result ADT :) 231 match self._check_key_registered( 232 self.subnets_whitelist, 233 headers_dict, 234 self.blockchain_cache, 235 self.host_key, 236 self.use_testnet, 237 ): 238 case (False, error): 239 return (False, error) 240 case (True, _): 241 pass 242 243 return (True, None) 244 245 def _get_headers_dict( 246 self, 247 headers: starlette.datastructures.Headers, 248 required: list[str], 249 optional: list[str], 250 ): 251 headers_dict: dict[str, str] = {} 252 for required_header in required: 253 value = headers.get(required_header) 254 if not value: 255 code = 400 256 return False, json_error(code, f"Missing header: {required_header}") 257 headers_dict[required_header] = value 258 for optional_header in optional: 259 value = headers.get(optional_header) 260 if value: 261 headers_dict[optional_header] = value 262 263 return True, headers_dict 264 265 def _check_signature( 266 self, 267 headers_dict: dict[str, str], 268 body: bytes, 269 module_key: Ss58Address 270 ): 271 key = headers_dict["x-key"] 272 signature = headers_dict["x-signature"] 273 crypto = int(headers_dict["x-crypto"]) # TODO: better handling of this 274 275 if not is_hex_string(key): 276 reason = "X-Key should be a hex value" 277 log_reffusal(key, reason) 278 return (False, json_error(400, reason)) 279 try: 280 signature = parse_hex(signature) 281 except Exception: 282 reason = "Signature sent is not a valid hex value" 283 log_reffusal(key, reason) 284 return False, json_error(400, reason) 285 try: 286 key = parse_hex(key) 287 except Exception: 288 reason = "Key sent is not a valid hex value" 289 log_reffusal(key, reason) 290 return False, json_error(400, reason) 291 # decodes the key for better logging 292 key_ss58 = try_ss58_decode(key) 293 if key_ss58 is None: 294 reason = "Caller key could not be decoded into a ss58address" 295 log_reffusal(key.decode(), reason) 296 return (False, 
json_error(400, reason)) 297 298 timestamp = headers_dict.get("x-timestamp") 299 legacy_verified = False 300 if timestamp: 301 # tries to do a legacy verification 302 json_body = json.loads(body) 303 json_body["timestamp"] = timestamp 304 stamped_body = json.dumps(json_body).encode() 305 legacy_verified = signer.verify(key, crypto, stamped_body, signature) 306 307 verified = signer.verify(key, crypto, body, signature) 308 if not verified and not legacy_verified: 309 reason = "Signature doesn't match" 310 log_reffusal(key_ss58, reason) 311 return (False, json_error(401, "Signatures doesn't match")) 312 313 body_dict: dict[str, dict[str, Any]] = json.loads(body) 314 target_key = body_dict['params'].get("target_key", None) 315 if not target_key or target_key != module_key: 316 reason = "Wrong target_key in body" 317 log_reffusal(key_ss58, reason) 318 return (False, json_error(401, "Wrong target_key in body")) 319 320 return (True, None) 321 322 def _check_key_registered( 323 self, 324 subnets_whitelist: list[int] | None, 325 headers_dict: dict[str, str], 326 blockchain_cache: TTLDict[str, list[Ss58Address]], 327 host_key: Keypair, 328 use_testnet: bool, 329 ): 330 key = headers_dict["x-key"] 331 if not is_hex_string(key): 332 return (False, json_error(400, "X-Key should be a hex value")) 333 key = parse_hex(key) 334 335 # TODO: checking for key being registered should be smarter 336 # e.g. query and store all registered modules periodically. 
337 338 ss58 = try_ss58_decode(key) 339 if ss58 is None: 340 reason = "Caller key could not be decoded into a ss58address" 341 log_reffusal(key.decode(), reason) 342 return (False, json_error(400, reason)) 343 344 # If subnets whitelist is specified, checks if key is registered in one 345 # of the given subnets 346 347 allowed_subnets: dict[int, bool] = {} 348 caller_subnets: list[int] = [] 349 if subnets_whitelist is not None: 350 def query_keys(subnet: int): 351 try: 352 node_url = get_node_url(None, use_testnet=use_testnet) 353 client = make_client(node_url) # TODO: get client from outer context 354 return [*client.query_map_key(subnet).values()] 355 except Exception: 356 log( 357 "WARNING: Could not connect to a blockchain node" 358 ) 359 return_list: list[Ss58Address] = [] 360 return return_list 361 362 # TODO: client pool for entire module server 363 364 got_keys = False 365 no_keys_reason = ( 366 "Miner could not connect to a blockchain node " 367 "or there is no key registered on the subnet(s) {} " 368 ) 369 for subnet in subnets_whitelist: 370 get_keys_on_subnet = partial(query_keys, subnet) 371 cache_key = f"keys_on_subnet_{subnet}" 372 keys_on_subnet = blockchain_cache.get_or_insert_lazy( 373 cache_key, get_keys_on_subnet 374 ) 375 if len(keys_on_subnet) == 0: 376 reason = no_keys_reason.format(subnet) 377 log(f"WARNING: {reason}") 378 else: 379 got_keys = True 380 if host_key.ss58_address not in keys_on_subnet: 381 log( 382 f"WARNING: This miner is deregistered on subnet {subnet}" 383 ) 384 else: 385 allowed_subnets[subnet] = True 386 if ss58 in keys_on_subnet: 387 caller_subnets.append(subnet) 388 if not got_keys: 389 return False, json_error(503, no_keys_reason.format(subnets_whitelist)) 390 if not allowed_subnets: 391 log("WARNING: Miner is not registered on any subnet") 392 return False, json_error(403, "Miner is not registered on any subnet") 393 394 # searches for a common subnet between caller and miner 395 # TODO: use sets 396 allowed_subnets = 
{ 397 subnet: allowed for subnet, allowed in allowed_subnets.items() if ( 398 subnet in caller_subnets 399 ) 400 } 401 if not allowed_subnets: 402 reason = "Caller key is not registered in any subnet that the miner is" 403 log_reffusal(ss58, reason) 404 return False, json_error( 405 403, reason 406 ) 407 else: 408 # accepts everything 409 pass 410 411 return (True, None)
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
def __init__(
    self,
    subnets_whitelist: list[int] | None,
    module_key: Ss58Address,
    request_staleness: int,
    blockchain_cache: TTLDict[str, list[Ss58Address]],
    host_key: Keypair,
    use_testnet: bool,
):
    """
    :param subnets_whitelist: subnets on which the caller must share a
        registration with this module; ``None`` skips that check.
    :param module_key: address callers must send as ``params.target_key``.
    :param request_staleness: maximum age, in seconds, of the request
        timestamp.
    :param blockchain_cache: cache of registered keys, keyed by
        ``"keys_on_subnet_<subnet>"``.
    :param host_key: this module's keypair; its address must be registered
        on at least one whitelisted subnet.
    :param use_testnet: passed to ``get_node_url`` when querying keys.
    """
    # settings used by the request checks
    self.module_key = module_key
    self.request_staleness = request_staleness
    # settings used by the on-chain registration check
    self.subnets_whitelist = subnets_whitelist
    self.blockchain_cache = blockchain_cache
    self.host_key = host_key
    self.use_testnet = use_testnet
186 async def verify(self, request: Request): 187 body = await request.body() 188 189 # TODO: we'll replace this by a Result ADT :) 190 match self._check_inputs(request, body, self.module_key): 191 case (False, error): 192 return error 193 case (True, _): 194 pass 195 196 body_dict: dict[str, dict[str, Any]] = json.loads(body) 197 timestamp = body_dict['params'].get("timestamp", None) 198 legacy_timestamp = request.headers.get("X-Timestamp", None) 199 try: 200 timestamp_to_use = timestamp if not legacy_timestamp else legacy_timestamp 201 request_time = datetime.fromisoformat(timestamp_to_use) 202 except Exception: 203 return JSONResponse(status_code=400, content={"error": "Invalid ISO timestamp given"}) 204 if (datetime.now(timezone.utc) - request_time).total_seconds() > self.request_staleness: 205 return JSONResponse(status_code=400, content={"error": "Request is too stale"}) 206 return None
Please don't mutate the request.
def build_route_class(
    verifiers: Sequence[AbstractVerifier]
) -> type[APIRoute]:
    """Create an ``APIRoute`` subclass that runs *verifiers* before handlers.

    Only requests whose path starts with ``/method`` are verified. The
    verifiers run in order and the first non-``None`` response is returned
    without invoking the route; other paths go straight to the handler.
    """

    class CheckListsRoute(APIRoute):
        def get_route_handler(self):
            route_handler = super().get_route_handler()

            async def custom_route_handler(request: Request) -> Response | JSONResponse:
                if request.url.path.startswith('/method'):
                    for verifier in verifiers:
                        rejection = await verifier.verify(request)
                        if rejection is not None:
                            return rejection
                handled: Response = await route_handler(request)
                return handled

            return custom_route_handler

    return CheckListsRoute