communex.module.routers.module_routers
import json
import re
from abc import abstractmethod
from datetime import datetime, timezone
from functools import partial
from typing import Any, Protocol, Sequence

import starlette.datastructures
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from keylimiter import TokenBucketLimiter
from substrateinterface import Keypair

from communex._common import get_node_url
from communex.module import _signer as signer
from communex.module._rate_limiters._stake_limiter import StakeLimiter
from communex.module._rate_limiters.limiters import (
    IpLimiterParams,
    StakeLimiterParams,
)
from communex.module._util import (
    json_error,
    log,
    log_reffusal,
    make_client,
    try_ss58_decode,
)
from communex.types import Ss58Address
from communex.util.memo import TTLDict

# Matches strings made up exclusively of hex digits (no "0x" prefix allowed).
HEX_PATTERN = re.compile(r"^[0-9a-fA-F]+$")


def is_hex_string(string: str) -> bool:
    """Return True if *string* consists only of hex digits (empty string -> False)."""
    return bool(HEX_PATTERN.match(string))


def parse_hex(hex_str: str) -> bytes:
    """Decode a hex string into bytes, tolerating an optional leading "0x".

    Raises:
        ValueError: if the remaining string is not valid hexadecimal.
    """
    # removeprefix is a no-op when the prefix is absent (requires Python 3.9+,
    # which the `list[int] | None` annotations below already imply).
    return bytes.fromhex(hex_str.removeprefix("0x"))


class AbstractVerifier(Protocol):
    """Structural interface for request verifiers used by the module server.

    Implementations inspect an incoming request and either return a
    ``JSONResponse`` describing the refusal, or ``None`` to let it through.
    """

    @abstractmethod
    async def verify(self, request: Request) -> JSONResponse | None:
        """Please dont mutate the request D:"""
        ...
51 52 53class StakeLimiterVerifier(AbstractVerifier): 54 def __init__( 55 self, 56 subnets_whitelist: list[int] | None, 57 params_: StakeLimiterParams | None, 58 ): 59 self.subnets_whitelist = subnets_whitelist 60 self.params_ = params_ 61 if not self.params_: 62 params = StakeLimiterParams() 63 else: 64 params = self.params_ 65 self.limiter = StakeLimiter( 66 self.subnets_whitelist, 67 epoch=params.epoch, 68 max_cache_age=params.cache_age, 69 get_refill_rate=params.get_refill_per_epoch, 70 ) 71 72 async def verify(self, request: Request): 73 if request.client is None: 74 response = JSONResponse( 75 status_code=401, 76 content={"error": "Address should be present in request"}, 77 ) 78 return response 79 80 key = request.headers.get("x-key") 81 if not key: 82 response = JSONResponse( 83 status_code=401, 84 content={"error": "Valid X-Key not provided on headers"}, 85 ) 86 return response 87 88 is_allowed = await self.limiter.allow(key) 89 90 if not is_allowed: 91 response = JSONResponse( 92 status_code=429, 93 headers={ 94 "X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds" 95 }, 96 content={"error": "Rate limit exceeded"}, 97 ) 98 return response 99 return None 100 101 102class ListVerifier(AbstractVerifier): 103 def __init__( 104 self, 105 blacklist: list[Ss58Address] | None, 106 whitelist: list[Ss58Address] | None, 107 ip_blacklist: list[str] | None, 108 ): 109 self.blacklist = blacklist 110 self.whitelist = whitelist 111 self.ip_blacklist = ip_blacklist 112 113 async def verify(self, request: Request) -> JSONResponse | None: 114 key = request.headers.get("x-key") 115 if not key: 116 reason = "Missing header: X-Key" 117 log(f"INFO: refusing module request because: {reason}") 118 return json_error(400, "Missing header: X-Key") 119 120 ss58 = try_ss58_decode(key) 121 if ss58 is None: 122 reason = "Caller key could not be decoded into a ss58address" 123 log_reffusal(key, reason) 124 return json_error(400, reason) 125 if request.client is 
None: 126 return json_error(400, "Address should be present in request") 127 if self.blacklist and ss58 in self.blacklist: 128 return json_error(403, "You are blacklisted") 129 if self.ip_blacklist and request.client.host in self.ip_blacklist: 130 return json_error(403, "Your IP is blacklisted") 131 if self.whitelist and ss58 not in self.whitelist: 132 return json_error(403, "You are not whitelisted") 133 return None 134 135 136class IpLimiterVerifier(AbstractVerifier): 137 def __init__( 138 self, 139 params: IpLimiterParams | None, 140 ): 141 """ 142 :param limiter: KeyLimiter instance OR None 143 144 If limiter is None, then a default TokenBucketLimiter is used with the following config: 145 bucket_size=200, refill_rate=15 146 """ 147 148 # fallback to default limiter 149 if not params: 150 params = IpLimiterParams() 151 self._limiter = TokenBucketLimiter( 152 bucket_size=params.bucket_size, refill_rate=params.refill_rate 153 ) 154 155 async def verify(self, request: Request): 156 assert request.client is not None, "request is invalid" 157 assert request.client.host, "request is invalid." 
158 159 ip = request.client.host 160 161 is_allowed = self._limiter.allow(ip) 162 163 if not is_allowed: 164 response = JSONResponse( 165 status_code=429, 166 headers={ 167 "X-RateLimit-Remaining": str(self._limiter.remaining(ip)) 168 }, 169 content={"error": "Rate limit exceeded"}, 170 ) 171 return response 172 return None 173 174 175class InputHandlerVerifier(AbstractVerifier): 176 def __init__( 177 self, 178 subnets_whitelist: list[int] | None, 179 module_key: Ss58Address, 180 request_staleness: int, 181 blockchain_cache: TTLDict[str, list[Ss58Address]], 182 host_key: Keypair, 183 use_testnet: bool, 184 ): 185 self.subnets_whitelist = subnets_whitelist 186 self.module_key = module_key 187 self.request_staleness = request_staleness 188 self.blockchain_cache = blockchain_cache 189 self.host_key = host_key 190 self.use_testnet = use_testnet 191 192 async def verify(self, request: Request): 193 body = await request.body() 194 195 # TODO: we'll replace this by a Result ADT :) 196 match self._check_inputs(request, body, self.module_key): 197 case (False, error): 198 return error 199 case (True, _): 200 pass 201 202 body_dict: dict[str, dict[str, Any]] = json.loads(body) 203 timestamp = body_dict["params"].get("timestamp", None) 204 legacy_timestamp = request.headers.get("X-Timestamp", None) 205 try: 206 timestamp_to_use = ( 207 timestamp if not legacy_timestamp else legacy_timestamp 208 ) 209 request_time = datetime.fromisoformat(timestamp_to_use) 210 except Exception: 211 return JSONResponse( 212 status_code=400, 213 content={"error": "Invalid ISO timestamp given"}, 214 ) 215 if ( 216 datetime.now(timezone.utc) - request_time 217 ).total_seconds() > self.request_staleness: 218 return JSONResponse( 219 status_code=400, content={"error": "Request is too stale"} 220 ) 221 return None 222 223 def _check_inputs( 224 self, request: Request, body: bytes, module_key: Ss58Address 225 ): 226 required_headers = ["x-signature", "x-key", "x-crypto"] 227 optional_headers = 
["x-timestamp"] 228 229 # TODO: we'll replace this by a Result ADT :) 230 match self._get_headers_dict( 231 request.headers, required_headers, optional_headers 232 ): 233 case (False, error): 234 return (False, error) 235 case (True, headers_dict): 236 pass 237 238 # TODO: we'll replace this by a Result ADT :) 239 match self._check_signature(headers_dict, body, module_key): 240 case (False, error): 241 return (False, error) 242 case (True, _): 243 pass 244 245 # TODO: we'll replace this by a Result ADT :) 246 match self._check_key_registered( 247 self.subnets_whitelist, 248 headers_dict, 249 self.blockchain_cache, 250 self.host_key, 251 self.use_testnet, 252 ): 253 case (False, error): 254 return (False, error) 255 case (True, _): 256 pass 257 258 return (True, None) 259 260 def _get_headers_dict( 261 self, 262 headers: starlette.datastructures.Headers, 263 required: list[str], 264 optional: list[str], 265 ): 266 headers_dict: dict[str, str] = {} 267 for required_header in required: 268 value = headers.get(required_header) 269 if not value: 270 code = 400 271 return False, json_error( 272 code, f"Missing header: {required_header}" 273 ) 274 headers_dict[required_header] = value 275 for optional_header in optional: 276 value = headers.get(optional_header) 277 if value: 278 headers_dict[optional_header] = value 279 280 return True, headers_dict 281 282 def _check_signature( 283 self, headers_dict: dict[str, str], body: bytes, module_key: Ss58Address 284 ): 285 key = headers_dict["x-key"] 286 signature = headers_dict["x-signature"] 287 crypto = int(headers_dict["x-crypto"]) # TODO: better handling of this 288 289 if not is_hex_string(key): 290 reason = "X-Key should be a hex value" 291 log_reffusal(key, reason) 292 return (False, json_error(400, reason)) 293 try: 294 signature = parse_hex(signature) 295 except Exception: 296 reason = "Signature sent is not a valid hex value" 297 log_reffusal(key, reason) 298 return False, json_error(400, reason) 299 try: 300 key = 
parse_hex(key) 301 except Exception: 302 reason = "Key sent is not a valid hex value" 303 log_reffusal(key, reason) 304 return False, json_error(400, reason) 305 # decodes the key for better logging 306 key_ss58 = try_ss58_decode(key) 307 if key_ss58 is None: 308 reason = "Caller key could not be decoded into a ss58address" 309 log_reffusal(key.decode(), reason) 310 return (False, json_error(400, reason)) 311 312 timestamp = headers_dict.get("x-timestamp") 313 legacy_verified = False 314 if timestamp: 315 # tries to do a legacy verification 316 json_body = json.loads(body) 317 json_body["timestamp"] = timestamp 318 stamped_body = json.dumps(json_body).encode() 319 legacy_verified = signer.verify( 320 key, crypto, stamped_body, signature 321 ) 322 323 verified = signer.verify(key, crypto, body, signature) 324 if not verified and not legacy_verified: 325 reason = "Signature doesn't match" 326 log_reffusal(key_ss58, reason) 327 return (False, json_error(401, "Signatures doesn't match")) 328 329 body_dict: dict[str, dict[str, Any]] = json.loads(body) 330 target_key = body_dict["params"].get("target_key", None) 331 if not target_key or target_key != module_key: 332 reason = "Wrong target_key in body" 333 log_reffusal(key_ss58, reason) 334 return (False, json_error(401, "Wrong target_key in body")) 335 336 return (True, None) 337 338 def _check_key_registered( 339 self, 340 subnets_whitelist: list[int] | None, 341 headers_dict: dict[str, str], 342 blockchain_cache: TTLDict[str, list[Ss58Address]], 343 host_key: Keypair, 344 use_testnet: bool, 345 ): 346 key = headers_dict["x-key"] 347 if not is_hex_string(key): 348 return (False, json_error(400, "X-Key should be a hex value")) 349 key = parse_hex(key) 350 351 # TODO: checking for key being registered should be smarter 352 # e.g. query and store all registered modules periodically. 
353 354 ss58 = try_ss58_decode(key) 355 if ss58 is None: 356 reason = "Caller key could not be decoded into a ss58address" 357 log_reffusal(key.decode(), reason) 358 return (False, json_error(400, reason)) 359 360 # If subnets whitelist is specified, checks if key is registered in one 361 # of the given subnets 362 363 allowed_subnets: dict[int, bool] = {} 364 caller_subnets: list[int] = [] 365 if subnets_whitelist is not None: 366 367 def query_keys(subnet: int): 368 try: 369 node_url = get_node_url(None, use_testnet=use_testnet) 370 client = make_client( 371 node_url 372 ) # TODO: get client from outer context 373 return [*client.query_map_key(subnet).values()] 374 except Exception: 375 log("WARNING: Could not connect to a blockchain node") 376 return_list: list[Ss58Address] = [] 377 return return_list 378 379 # TODO: client pool for entire module server 380 381 got_keys = False 382 no_keys_reason = ( 383 "Miner could not connect to a blockchain node " 384 "or there is no key registered on the subnet(s) {} " 385 ) 386 for subnet in subnets_whitelist: 387 get_keys_on_subnet = partial(query_keys, subnet) 388 cache_key = f"keys_on_subnet_{subnet}" 389 keys_on_subnet = blockchain_cache.get_or_insert_lazy( 390 cache_key, get_keys_on_subnet 391 ) 392 if len(keys_on_subnet) == 0: 393 reason = no_keys_reason.format(subnet) 394 log(f"WARNING: {reason}") 395 else: 396 got_keys = True 397 if host_key.ss58_address not in keys_on_subnet: 398 log( 399 f"WARNING: This miner is deregistered on subnet {subnet}" 400 ) 401 else: 402 allowed_subnets[subnet] = True 403 if ss58 in keys_on_subnet: 404 caller_subnets.append(subnet) 405 if not got_keys: 406 return False, json_error( 407 503, no_keys_reason.format(subnets_whitelist) 408 ) 409 if not allowed_subnets: 410 log("WARNING: Miner is not registered on any subnet") 411 return False, json_error( 412 403, "Miner is not registered on any subnet" 413 ) 414 415 # searches for a common subnet between caller and miner 416 # TODO: use 
sets 417 allowed_subnets = { 418 subnet: allowed 419 for subnet, allowed in allowed_subnets.items() 420 if (subnet in caller_subnets) 421 } 422 if not allowed_subnets: 423 reason = "Caller key is not registered in any subnet that the miner is" 424 log_reffusal(ss58, reason) 425 return False, json_error(403, reason) 426 else: 427 # accepts everything 428 pass 429 430 return (True, None) 431 432 433def build_route_class(verifiers: Sequence[AbstractVerifier]) -> type[APIRoute]: 434 class CheckListsRoute(APIRoute): 435 def get_route_handler(self): 436 original_route_handler = super().get_route_handler() 437 438 async def custom_route_handler( 439 request: Request, 440 ) -> Response | JSONResponse: 441 if not request.url.path.startswith("/method"): 442 unhandled_response: Response = await original_route_handler( 443 request 444 ) 445 return unhandled_response 446 for verifier in verifiers: 447 response = await verifier.verify(request) 448 if response is not None: 449 return response 450 451 original_response: Response = await original_route_handler( 452 request 453 ) 454 return original_response 455 456 return custom_route_handler 457 458 return CheckListsRoute
47class AbstractVerifier(Protocol): 48 @abstractmethod 49 async def verify(self, request: Request) -> JSONResponse | None: 50 """Please dont mutate the request D:""" 51 ...
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
54class StakeLimiterVerifier(AbstractVerifier): 55 def __init__( 56 self, 57 subnets_whitelist: list[int] | None, 58 params_: StakeLimiterParams | None, 59 ): 60 self.subnets_whitelist = subnets_whitelist 61 self.params_ = params_ 62 if not self.params_: 63 params = StakeLimiterParams() 64 else: 65 params = self.params_ 66 self.limiter = StakeLimiter( 67 self.subnets_whitelist, 68 epoch=params.epoch, 69 max_cache_age=params.cache_age, 70 get_refill_rate=params.get_refill_per_epoch, 71 ) 72 73 async def verify(self, request: Request): 74 if request.client is None: 75 response = JSONResponse( 76 status_code=401, 77 content={"error": "Address should be present in request"}, 78 ) 79 return response 80 81 key = request.headers.get("x-key") 82 if not key: 83 response = JSONResponse( 84 status_code=401, 85 content={"error": "Valid X-Key not provided on headers"}, 86 ) 87 return response 88 89 is_allowed = await self.limiter.allow(key) 90 91 if not is_allowed: 92 response = JSONResponse( 93 status_code=429, 94 headers={ 95 "X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds" 96 }, 97 content={"error": "Rate limit exceeded"}, 98 ) 99 return response 100 return None
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
55 def __init__( 56 self, 57 subnets_whitelist: list[int] | None, 58 params_: StakeLimiterParams | None, 59 ): 60 self.subnets_whitelist = subnets_whitelist 61 self.params_ = params_ 62 if not self.params_: 63 params = StakeLimiterParams() 64 else: 65 params = self.params_ 66 self.limiter = StakeLimiter( 67 self.subnets_whitelist, 68 epoch=params.epoch, 69 max_cache_age=params.cache_age, 70 get_refill_rate=params.get_refill_per_epoch, 71 )
73 async def verify(self, request: Request): 74 if request.client is None: 75 response = JSONResponse( 76 status_code=401, 77 content={"error": "Address should be present in request"}, 78 ) 79 return response 80 81 key = request.headers.get("x-key") 82 if not key: 83 response = JSONResponse( 84 status_code=401, 85 content={"error": "Valid X-Key not provided on headers"}, 86 ) 87 return response 88 89 is_allowed = await self.limiter.allow(key) 90 91 if not is_allowed: 92 response = JSONResponse( 93 status_code=429, 94 headers={ 95 "X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds" 96 }, 97 content={"error": "Rate limit exceeded"}, 98 ) 99 return response 100 return None
Please dont mutate the request D:
103class ListVerifier(AbstractVerifier): 104 def __init__( 105 self, 106 blacklist: list[Ss58Address] | None, 107 whitelist: list[Ss58Address] | None, 108 ip_blacklist: list[str] | None, 109 ): 110 self.blacklist = blacklist 111 self.whitelist = whitelist 112 self.ip_blacklist = ip_blacklist 113 114 async def verify(self, request: Request) -> JSONResponse | None: 115 key = request.headers.get("x-key") 116 if not key: 117 reason = "Missing header: X-Key" 118 log(f"INFO: refusing module request because: {reason}") 119 return json_error(400, "Missing header: X-Key") 120 121 ss58 = try_ss58_decode(key) 122 if ss58 is None: 123 reason = "Caller key could not be decoded into a ss58address" 124 log_reffusal(key, reason) 125 return json_error(400, reason) 126 if request.client is None: 127 return json_error(400, "Address should be present in request") 128 if self.blacklist and ss58 in self.blacklist: 129 return json_error(403, "You are blacklisted") 130 if self.ip_blacklist and request.client.host in self.ip_blacklist: 131 return json_error(403, "Your IP is blacklisted") 132 if self.whitelist and ss58 not in self.whitelist: 133 return json_error(403, "You are not whitelisted") 134 return None
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
114 async def verify(self, request: Request) -> JSONResponse | None: 115 key = request.headers.get("x-key") 116 if not key: 117 reason = "Missing header: X-Key" 118 log(f"INFO: refusing module request because: {reason}") 119 return json_error(400, "Missing header: X-Key") 120 121 ss58 = try_ss58_decode(key) 122 if ss58 is None: 123 reason = "Caller key could not be decoded into a ss58address" 124 log_reffusal(key, reason) 125 return json_error(400, reason) 126 if request.client is None: 127 return json_error(400, "Address should be present in request") 128 if self.blacklist and ss58 in self.blacklist: 129 return json_error(403, "You are blacklisted") 130 if self.ip_blacklist and request.client.host in self.ip_blacklist: 131 return json_error(403, "Your IP is blacklisted") 132 if self.whitelist and ss58 not in self.whitelist: 133 return json_error(403, "You are not whitelisted") 134 return None
Please dont mutate the request D:
137class IpLimiterVerifier(AbstractVerifier): 138 def __init__( 139 self, 140 params: IpLimiterParams | None, 141 ): 142 """ 143 :param limiter: KeyLimiter instance OR None 144 145 If limiter is None, then a default TokenBucketLimiter is used with the following config: 146 bucket_size=200, refill_rate=15 147 """ 148 149 # fallback to default limiter 150 if not params: 151 params = IpLimiterParams() 152 self._limiter = TokenBucketLimiter( 153 bucket_size=params.bucket_size, refill_rate=params.refill_rate 154 ) 155 156 async def verify(self, request: Request): 157 assert request.client is not None, "request is invalid" 158 assert request.client.host, "request is invalid." 159 160 ip = request.client.host 161 162 is_allowed = self._limiter.allow(ip) 163 164 if not is_allowed: 165 response = JSONResponse( 166 status_code=429, 167 headers={ 168 "X-RateLimit-Remaining": str(self._limiter.remaining(ip)) 169 }, 170 content={"error": "Rate limit exceeded"}, 171 ) 172 return response 173 return None
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
138 def __init__( 139 self, 140 params: IpLimiterParams | None, 141 ): 142 """ 143 :param limiter: KeyLimiter instance OR None 144 145 If limiter is None, then a default TokenBucketLimiter is used with the following config: 146 bucket_size=200, refill_rate=15 147 """ 148 149 # fallback to default limiter 150 if not params: 151 params = IpLimiterParams() 152 self._limiter = TokenBucketLimiter( 153 bucket_size=params.bucket_size, refill_rate=params.refill_rate 154 )
Parameters
- limiter: KeyLimiter instance OR None
If limiter is None, then a default TokenBucketLimiter is used with the following config: bucket_size=200, refill_rate=15
156 async def verify(self, request: Request): 157 assert request.client is not None, "request is invalid" 158 assert request.client.host, "request is invalid." 159 160 ip = request.client.host 161 162 is_allowed = self._limiter.allow(ip) 163 164 if not is_allowed: 165 response = JSONResponse( 166 status_code=429, 167 headers={ 168 "X-RateLimit-Remaining": str(self._limiter.remaining(ip)) 169 }, 170 content={"error": "Rate limit exceeded"}, 171 ) 172 return response 173 return None
Please dont mutate the request D:
176class InputHandlerVerifier(AbstractVerifier): 177 def __init__( 178 self, 179 subnets_whitelist: list[int] | None, 180 module_key: Ss58Address, 181 request_staleness: int, 182 blockchain_cache: TTLDict[str, list[Ss58Address]], 183 host_key: Keypair, 184 use_testnet: bool, 185 ): 186 self.subnets_whitelist = subnets_whitelist 187 self.module_key = module_key 188 self.request_staleness = request_staleness 189 self.blockchain_cache = blockchain_cache 190 self.host_key = host_key 191 self.use_testnet = use_testnet 192 193 async def verify(self, request: Request): 194 body = await request.body() 195 196 # TODO: we'll replace this by a Result ADT :) 197 match self._check_inputs(request, body, self.module_key): 198 case (False, error): 199 return error 200 case (True, _): 201 pass 202 203 body_dict: dict[str, dict[str, Any]] = json.loads(body) 204 timestamp = body_dict["params"].get("timestamp", None) 205 legacy_timestamp = request.headers.get("X-Timestamp", None) 206 try: 207 timestamp_to_use = ( 208 timestamp if not legacy_timestamp else legacy_timestamp 209 ) 210 request_time = datetime.fromisoformat(timestamp_to_use) 211 except Exception: 212 return JSONResponse( 213 status_code=400, 214 content={"error": "Invalid ISO timestamp given"}, 215 ) 216 if ( 217 datetime.now(timezone.utc) - request_time 218 ).total_seconds() > self.request_staleness: 219 return JSONResponse( 220 status_code=400, content={"error": "Request is too stale"} 221 ) 222 return None 223 224 def _check_inputs( 225 self, request: Request, body: bytes, module_key: Ss58Address 226 ): 227 required_headers = ["x-signature", "x-key", "x-crypto"] 228 optional_headers = ["x-timestamp"] 229 230 # TODO: we'll replace this by a Result ADT :) 231 match self._get_headers_dict( 232 request.headers, required_headers, optional_headers 233 ): 234 case (False, error): 235 return (False, error) 236 case (True, headers_dict): 237 pass 238 239 # TODO: we'll replace this by a Result ADT :) 240 match 
self._check_signature(headers_dict, body, module_key): 241 case (False, error): 242 return (False, error) 243 case (True, _): 244 pass 245 246 # TODO: we'll replace this by a Result ADT :) 247 match self._check_key_registered( 248 self.subnets_whitelist, 249 headers_dict, 250 self.blockchain_cache, 251 self.host_key, 252 self.use_testnet, 253 ): 254 case (False, error): 255 return (False, error) 256 case (True, _): 257 pass 258 259 return (True, None) 260 261 def _get_headers_dict( 262 self, 263 headers: starlette.datastructures.Headers, 264 required: list[str], 265 optional: list[str], 266 ): 267 headers_dict: dict[str, str] = {} 268 for required_header in required: 269 value = headers.get(required_header) 270 if not value: 271 code = 400 272 return False, json_error( 273 code, f"Missing header: {required_header}" 274 ) 275 headers_dict[required_header] = value 276 for optional_header in optional: 277 value = headers.get(optional_header) 278 if value: 279 headers_dict[optional_header] = value 280 281 return True, headers_dict 282 283 def _check_signature( 284 self, headers_dict: dict[str, str], body: bytes, module_key: Ss58Address 285 ): 286 key = headers_dict["x-key"] 287 signature = headers_dict["x-signature"] 288 crypto = int(headers_dict["x-crypto"]) # TODO: better handling of this 289 290 if not is_hex_string(key): 291 reason = "X-Key should be a hex value" 292 log_reffusal(key, reason) 293 return (False, json_error(400, reason)) 294 try: 295 signature = parse_hex(signature) 296 except Exception: 297 reason = "Signature sent is not a valid hex value" 298 log_reffusal(key, reason) 299 return False, json_error(400, reason) 300 try: 301 key = parse_hex(key) 302 except Exception: 303 reason = "Key sent is not a valid hex value" 304 log_reffusal(key, reason) 305 return False, json_error(400, reason) 306 # decodes the key for better logging 307 key_ss58 = try_ss58_decode(key) 308 if key_ss58 is None: 309 reason = "Caller key could not be decoded into a ss58address" 
310 log_reffusal(key.decode(), reason) 311 return (False, json_error(400, reason)) 312 313 timestamp = headers_dict.get("x-timestamp") 314 legacy_verified = False 315 if timestamp: 316 # tries to do a legacy verification 317 json_body = json.loads(body) 318 json_body["timestamp"] = timestamp 319 stamped_body = json.dumps(json_body).encode() 320 legacy_verified = signer.verify( 321 key, crypto, stamped_body, signature 322 ) 323 324 verified = signer.verify(key, crypto, body, signature) 325 if not verified and not legacy_verified: 326 reason = "Signature doesn't match" 327 log_reffusal(key_ss58, reason) 328 return (False, json_error(401, "Signatures doesn't match")) 329 330 body_dict: dict[str, dict[str, Any]] = json.loads(body) 331 target_key = body_dict["params"].get("target_key", None) 332 if not target_key or target_key != module_key: 333 reason = "Wrong target_key in body" 334 log_reffusal(key_ss58, reason) 335 return (False, json_error(401, "Wrong target_key in body")) 336 337 return (True, None) 338 339 def _check_key_registered( 340 self, 341 subnets_whitelist: list[int] | None, 342 headers_dict: dict[str, str], 343 blockchain_cache: TTLDict[str, list[Ss58Address]], 344 host_key: Keypair, 345 use_testnet: bool, 346 ): 347 key = headers_dict["x-key"] 348 if not is_hex_string(key): 349 return (False, json_error(400, "X-Key should be a hex value")) 350 key = parse_hex(key) 351 352 # TODO: checking for key being registered should be smarter 353 # e.g. query and store all registered modules periodically. 
354 355 ss58 = try_ss58_decode(key) 356 if ss58 is None: 357 reason = "Caller key could not be decoded into a ss58address" 358 log_reffusal(key.decode(), reason) 359 return (False, json_error(400, reason)) 360 361 # If subnets whitelist is specified, checks if key is registered in one 362 # of the given subnets 363 364 allowed_subnets: dict[int, bool] = {} 365 caller_subnets: list[int] = [] 366 if subnets_whitelist is not None: 367 368 def query_keys(subnet: int): 369 try: 370 node_url = get_node_url(None, use_testnet=use_testnet) 371 client = make_client( 372 node_url 373 ) # TODO: get client from outer context 374 return [*client.query_map_key(subnet).values()] 375 except Exception: 376 log("WARNING: Could not connect to a blockchain node") 377 return_list: list[Ss58Address] = [] 378 return return_list 379 380 # TODO: client pool for entire module server 381 382 got_keys = False 383 no_keys_reason = ( 384 "Miner could not connect to a blockchain node " 385 "or there is no key registered on the subnet(s) {} " 386 ) 387 for subnet in subnets_whitelist: 388 get_keys_on_subnet = partial(query_keys, subnet) 389 cache_key = f"keys_on_subnet_{subnet}" 390 keys_on_subnet = blockchain_cache.get_or_insert_lazy( 391 cache_key, get_keys_on_subnet 392 ) 393 if len(keys_on_subnet) == 0: 394 reason = no_keys_reason.format(subnet) 395 log(f"WARNING: {reason}") 396 else: 397 got_keys = True 398 if host_key.ss58_address not in keys_on_subnet: 399 log( 400 f"WARNING: This miner is deregistered on subnet {subnet}" 401 ) 402 else: 403 allowed_subnets[subnet] = True 404 if ss58 in keys_on_subnet: 405 caller_subnets.append(subnet) 406 if not got_keys: 407 return False, json_error( 408 503, no_keys_reason.format(subnets_whitelist) 409 ) 410 if not allowed_subnets: 411 log("WARNING: Miner is not registered on any subnet") 412 return False, json_error( 413 403, "Miner is not registered on any subnet" 414 ) 415 416 # searches for a common subnet between caller and miner 417 # TODO: use 
sets 418 allowed_subnets = { 419 subnet: allowed 420 for subnet, allowed in allowed_subnets.items() 421 if (subnet in caller_subnets) 422 } 423 if not allowed_subnets: 424 reason = "Caller key is not registered in any subnet that the miner is" 425 log_reffusal(ss58, reason) 426 return False, json_error(403, reason) 427 else: 428 # accepts everything 429 pass 430 431 return (True, None)
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
177 def __init__( 178 self, 179 subnets_whitelist: list[int] | None, 180 module_key: Ss58Address, 181 request_staleness: int, 182 blockchain_cache: TTLDict[str, list[Ss58Address]], 183 host_key: Keypair, 184 use_testnet: bool, 185 ): 186 self.subnets_whitelist = subnets_whitelist 187 self.module_key = module_key 188 self.request_staleness = request_staleness 189 self.blockchain_cache = blockchain_cache 190 self.host_key = host_key 191 self.use_testnet = use_testnet
193 async def verify(self, request: Request): 194 body = await request.body() 195 196 # TODO: we'll replace this by a Result ADT :) 197 match self._check_inputs(request, body, self.module_key): 198 case (False, error): 199 return error 200 case (True, _): 201 pass 202 203 body_dict: dict[str, dict[str, Any]] = json.loads(body) 204 timestamp = body_dict["params"].get("timestamp", None) 205 legacy_timestamp = request.headers.get("X-Timestamp", None) 206 try: 207 timestamp_to_use = ( 208 timestamp if not legacy_timestamp else legacy_timestamp 209 ) 210 request_time = datetime.fromisoformat(timestamp_to_use) 211 except Exception: 212 return JSONResponse( 213 status_code=400, 214 content={"error": "Invalid ISO timestamp given"}, 215 ) 216 if ( 217 datetime.now(timezone.utc) - request_time 218 ).total_seconds() > self.request_staleness: 219 return JSONResponse( 220 status_code=400, content={"error": "Request is too stale"} 221 ) 222 return None
Please dont mutate the request D:
434def build_route_class(verifiers: Sequence[AbstractVerifier]) -> type[APIRoute]: 435 class CheckListsRoute(APIRoute): 436 def get_route_handler(self): 437 original_route_handler = super().get_route_handler() 438 439 async def custom_route_handler( 440 request: Request, 441 ) -> Response | JSONResponse: 442 if not request.url.path.startswith("/method"): 443 unhandled_response: Response = await original_route_handler( 444 request 445 ) 446 return unhandled_response 447 for verifier in verifiers: 448 response = await verifier.verify(request) 449 if response is not None: 450 return response 451 452 original_response: Response = await original_route_handler( 453 request 454 ) 455 return original_response 456 457 return custom_route_handler 458 459 return CheckListsRoute