"""communex.module.routers.module_routers

Request verifiers (rate limiting, allow/deny lists, signature and
registration checks) used to guard the routes of a module server.
"""

import json
import re
from abc import abstractmethod
from datetime import datetime, timezone
from functools import partial
from typing import Any, Protocol, Sequence

import starlette.datastructures
from communex._common import get_node_url
from communex.module import _signer as signer
from communex.module._rate_limiters._stake_limiter import StakeLimiter
from communex.module._rate_limiters.limiters import (
    IpLimiterParams,
    StakeLimiterParams,
)
from communex.module._util import (
    json_error,
    log,
    log_reffusal,
    make_client,
    try_ss58_decode,
)
from communex.types import Ss58Address
from communex.util.memo import TTLDict
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from keylimiter import TokenBucketLimiter
from substrateinterface import Keypair

# Matches a non-empty run of hexadecimal digits (no "0x" prefix allowed).
HEX_PATTERN = re.compile(r"^[0-9a-fA-F]+$")


def is_hex_string(string: str) -> bool:
    """Return True if *string* is a non-empty run of hex digits."""
    return bool(HEX_PATTERN.match(string))


def parse_hex(hex_str: str) -> bytes:
    """Decode a hex string into bytes, tolerating an optional "0x" prefix."""
    return bytes.fromhex(hex_str.removeprefix("0x"))


class AbstractVerifier(Protocol):
    """Structural interface for request verifiers.

    A verifier inspects an incoming request and returns a ``JSONResponse``
    to refuse it, or ``None`` to let it through to the next verifier.
    """

    @abstractmethod
    async def verify(self, request: Request) -> JSONResponse | None:
        """Please dont mutate the request D:"""
        ...


class StakeLimiterVerifier(AbstractVerifier):
    """Refuses requests whose caller key exceeds a stake-based rate limit.

    The per-key allowance is computed by ``StakeLimiter``; presumably a
    higher on-chain stake grants a higher refill rate — confirm against
    the ``StakeLimiter`` implementation.
    """

    def __init__(
        self,
        subnets_whitelist: list[int] | None,
        params_: StakeLimiterParams | None,
    ):
        # Subnets the limiter should consider; forwarded as-is to StakeLimiter.
        self.subnets_whitelist = subnets_whitelist
        self.params_ = params_
        # Fall back to default limiter parameters when none were provided.
        if not self.params_:
            params = StakeLimiterParams()
        else:
            params = self.params_
        self.limiter = StakeLimiter(
            self.subnets_whitelist,
            epoch=params.epoch,
            max_cache_age=params.cache_age,
            get_refill_rate=params.get_refill_per_epoch,
        )

    async def verify(self, request: Request) -> JSONResponse | None:
        """Return a refusal response when rate-limited, or None to allow."""
        if request.client is None:
            response = JSONResponse(
                status_code=401,
                content={"error": "Address should be present in request"},
            )
            return response

        # The caller is identified (and therefore rate-limited) by its key.
        key = request.headers.get("x-key")
        if not key:
            response = JSONResponse(
                status_code=401,
                content={"error": "Valid X-Key not provided on headers"},
            )
            return response

        is_allowed = await self.limiter.allow(key)

        if not is_allowed:
            # Tell the client how long to back off before retrying.
            response = JSONResponse(
                status_code=429,
                headers={
                    "X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds"
                },
                content={"error": "Rate limit exceeded"},
            )
            return response
        return None


class ListVerifier(AbstractVerifier):
    """Refuses requests via key blacklist/whitelist and an IP blacklist."""

    def __init__(
        self,
        blacklist: list[Ss58Address] | None,
        whitelist: list[Ss58Address] | None,
        ip_blacklist: list[str] | None,
    ):
        # Any of these may be None, in which case that check is skipped.
        self.blacklist = blacklist
        self.whitelist = whitelist
        self.ip_blacklist = ip_blacklist

    async def verify(self, request: Request) -> JSONResponse | None:
        """Return a refusal response for listed callers, or None to allow."""
        key = request.headers.get("x-key")
        if not key:
            reason = "Missing header: X-Key"
            log(f"INFO: refusing module request because: {reason}")
            return json_error(400, "Missing header: X-Key")

        ss58 = try_ss58_decode(key)
        if ss58 is None:
            reason = "Caller key could not be decoded into a ss58address"
            log_reffusal(key, reason)
            return json_error(400, reason)
        if request.client is None:
            return json_error(400, "Address should be present in request")
        if self.blacklist and ss58 in self.blacklist:
            return json_error(403, "You are blacklisted")
        if self.ip_blacklist and request.client.host in self.ip_blacklist:
            return json_error(403, "Your IP is blacklisted")
        if self.whitelist and ss58 not in self.whitelist:
            return json_error(403, "You are not whitelisted")
        return None


class IpLimiterVerifier(AbstractVerifier):
    """Refuses requests whose client IP exceeds a token-bucket rate limit."""

    def __init__(
        self,
        params: IpLimiterParams | None,
    ):
        """
        :param limiter: KeyLimiter instance OR None

        If limiter is None, then a default TokenBucketLimiter is used with the following config:
        bucket_size=200, refill_rate=15
        """

        # fallback to default limiter
        if not params:
            params = IpLimiterParams()
        self._limiter = TokenBucketLimiter(
            bucket_size=params.bucket_size, refill_rate=params.refill_rate
        )

    async def verify(self, request: Request) -> JSONResponse | None:
        """Return a 429 response when the client IP is rate-limited, or None."""
        assert request.client is not None, "request is invalid"
        assert request.client.host, "request is invalid."

        ip = request.client.host

        is_allowed = self._limiter.allow(ip)

        if not is_allowed:
            response = JSONResponse(
                status_code=429,
                headers={
                    "X-RateLimit-Remaining": str(self._limiter.remaining(ip))
                },
                content={"error": "Rate limit exceeded"},
            )
            return response
        return None


class InputHandlerVerifier(AbstractVerifier):
    """Validates headers, body signature, caller registration and staleness."""

    def __init__(
        self,
        subnets_whitelist: list[int] | None,
        module_key: Ss58Address,
        request_staleness: int,
        blockchain_cache: TTLDict[str, list[Ss58Address]],
        host_key: Keypair,
        use_testnet: bool,
    ):
        # Subnets the caller must be registered on (None disables the check).
        self.subnets_whitelist = subnets_whitelist
        # This module's own key; the signed body must target it explicitly.
        self.module_key = module_key
        # Maximum accepted age of a request, in seconds.
        self.request_staleness = request_staleness
        # Cache of subnet -> registered keys, to avoid re-querying the chain.
        self.blockchain_cache = blockchain_cache
        self.host_key = host_key
        self.use_testnet = use_testnet

    async def verify(self, request: Request) -> JSONResponse | None:
        """Run all input checks; return a refusal response or None to allow."""
        body = await request.body()

        # TODO: we'll replace this by a Result ADT :)
        match self._check_inputs(request, body, self.module_key):
            case (False, error):
                return error
            case (True, _):
                pass

        body_dict: dict[str, dict[str, Any]] = json.loads(body)
        timestamp = body_dict["params"].get("timestamp", None)
        # Legacy clients send the timestamp in a header instead of the body;
        # when present, the header takes precedence.
        legacy_timestamp = request.headers.get("X-Timestamp", None)
        try:
            timestamp_to_use = (
                timestamp if not legacy_timestamp else legacy_timestamp
            )
            request_time = datetime.fromisoformat(timestamp_to_use)
        except Exception:
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid ISO timestamp given"},
            )
        # NOTE(review): assumes the parsed timestamp is timezone-aware; a
        # naive timestamp would make this subtraction raise TypeError —
        # confirm clients always include a UTC offset.
        if (
            datetime.now(timezone.utc) - request_time
        ).total_seconds() > self.request_staleness:
            return JSONResponse(
                status_code=400, content={"error": "Request is too stale"}
            )
        return None

    def _check_inputs(
        self, request: Request, body: bytes, module_key: Ss58Address
    ):
        """Chain the header, signature and registration checks.

        Returns ``(True, None)`` on success, or ``(False, <error response>)``
        for the first failing check.
        """
        required_headers = ["x-signature", "x-key", "x-crypto"]
        optional_headers = ["x-timestamp"]

        # TODO: we'll replace this by a Result ADT :)
        match self._get_headers_dict(
            request.headers, required_headers, optional_headers
        ):
            case (False, error):
                return (False, error)
            case (True, headers_dict):
                pass

        # TODO: we'll replace this by a Result ADT :)
        match self._check_signature(headers_dict, body, module_key):
            case (False, error):
                return (False, error)
            case (True, _):
                pass

        # TODO: we'll replace this by a Result ADT :)
        match self._check_key_registered(
            self.subnets_whitelist,
            headers_dict,
            self.blockchain_cache,
            self.host_key,
            self.use_testnet,
        ):
            case (False, error):
                return (False, error)
            case (True, _):
                pass

        return (True, None)

    def _get_headers_dict(
        self,
        headers: starlette.datastructures.Headers,
        required: list[str],
        optional: list[str],
    ):
        """Collect required and optional headers into a plain dict.

        Returns ``(True, dict)`` on success, or ``(False, <400 response>)``
        when a required header is missing.
        """
        headers_dict: dict[str, str] = {}
        for required_header in required:
            value = headers.get(required_header)
            if not value:
                code = 400
                return False, json_error(
                    code, f"Missing header: {required_header}"
                )
            headers_dict[required_header] = value
        # Optional headers are included only when present.
        for optional_header in optional:
            value = headers.get(optional_header)
            if value:
                headers_dict[optional_header] = value

        return True, headers_dict

    def _check_signature(
        self, headers_dict: dict[str, str], body: bytes, module_key: Ss58Address
    ):
        """Verify the body signature and the body's target key.

        Returns ``(True, None)`` or ``(False, <error response>)``.
        """
        key = headers_dict["x-key"]
        signature = headers_dict["x-signature"]
        crypto = int(headers_dict["x-crypto"])  # TODO: better handling of this

        if not is_hex_string(key):
            reason = "X-Key should be a hex value"
            log_reffusal(key, reason)
            return (False, json_error(400, reason))
        try:
            signature = parse_hex(signature)
        except Exception:
            reason = "Signature sent is not a valid hex value"
            log_reffusal(key, reason)
            return False, json_error(400, reason)
        try:
            key = parse_hex(key)
        except Exception:
            reason = "Key sent is not a valid hex value"
            log_reffusal(key, reason)
            return False, json_error(400, reason)
        # decodes the key for better logging
        key_ss58 = try_ss58_decode(key)
        if key_ss58 is None:
            reason = "Caller key could not be decoded into a ss58address"
            log_reffusal(key.decode(), reason)
            return (False, json_error(400, reason))

        timestamp = headers_dict.get("x-timestamp")
        legacy_verified = False
        if timestamp:
            # tries to do a legacy verification
            # (legacy clients signed the body with the timestamp embedded)
            json_body = json.loads(body)
            json_body["timestamp"] = timestamp
            stamped_body = json.dumps(json_body).encode()
            legacy_verified = signer.verify(
                key, crypto, stamped_body, signature
            )

        # Either the modern or the legacy signature scheme must verify.
        verified = signer.verify(key, crypto, body, signature)
        if not verified and not legacy_verified:
            reason = "Signature doesn't match"
            log_reffusal(key_ss58, reason)
            return (False, json_error(401, "Signatures doesn't match"))

        # The signed body must explicitly target this module's key, so a
        # request signed for one module can't be replayed against another.
        body_dict: dict[str, dict[str, Any]] = json.loads(body)
        target_key = body_dict["params"].get("target_key", None)
        if not target_key or target_key != module_key:
            reason = "Wrong target_key in body"
            log_reffusal(key_ss58, reason)
            return (False, json_error(401, "Wrong target_key in body"))

        return (True, None)

    def _check_key_registered(
        self,
        subnets_whitelist: list[int] | None,
        headers_dict: dict[str, str],
        blockchain_cache: TTLDict[str, list[Ss58Address]],
        host_key: Keypair,
        use_testnet: bool,
    ):
        """Check the caller shares a whitelisted subnet with this miner.

        Returns ``(True, None)`` or ``(False, <error response>)``. When
        *subnets_whitelist* is None the check is skipped entirely.
        """
        key = headers_dict["x-key"]
        if not is_hex_string(key):
            return (False, json_error(400, "X-Key should be a hex value"))
        key = parse_hex(key)

        # TODO: checking for key being registered should be smarter
        # e.g. query and store all registered modules periodically.

        ss58 = try_ss58_decode(key)
        if ss58 is None:
            reason = "Caller key could not be decoded into a ss58address"
            log_reffusal(key.decode(), reason)
            return (False, json_error(400, reason))

        # If subnets whitelist is specified, checks if key is registered in one
        # of the given subnets

        allowed_subnets: dict[int, bool] = {}
        caller_subnets: list[int] = []
        if subnets_whitelist is not None:

            def query_keys(subnet: int):
                # Queries the chain for all keys registered on *subnet*;
                # returns an empty list on connection failure (best-effort).
                try:
                    node_url = get_node_url(None, use_testnet=use_testnet)
                    client = make_client(
                        node_url
                    )  # TODO: get client from outer context
                    return [*client.query_map_key(subnet).values()]
                except Exception:
                    log("WARNING: Could not connect to a blockchain node")
                    return_list: list[Ss58Address] = []
                    return return_list

            # TODO: client pool for entire module server

            got_keys = False
            no_keys_reason = (
                "Miner could not connect to a blockchain node "
                "or there is no key registered on the subnet(s) {} "
            )
            for subnet in subnets_whitelist:
                get_keys_on_subnet = partial(query_keys, subnet)
                cache_key = f"keys_on_subnet_{subnet}"
                # Cached so repeated requests don't hit the chain every time.
                keys_on_subnet = blockchain_cache.get_or_insert_lazy(
                    cache_key, get_keys_on_subnet
                )
                if len(keys_on_subnet) == 0:
                    reason = no_keys_reason.format(subnet)
                    log(f"WARNING: {reason}")
                else:
                    got_keys = True
                if host_key.ss58_address not in keys_on_subnet:
                    log(
                        f"WARNING: This miner is deregistered on subnet {subnet}"
                    )
                else:
                    allowed_subnets[subnet] = True
                if ss58 in keys_on_subnet:
                    caller_subnets.append(subnet)
            if not got_keys:
                return False, json_error(
                    503, no_keys_reason.format(subnets_whitelist)
                )
            if not allowed_subnets:
                log("WARNING: Miner is not registered on any subnet")
                return False, json_error(
                    403, "Miner is not registered on any subnet"
                )

            # searches for a common subnet between caller and miner
            # TODO: use sets
            allowed_subnets = {
                subnet: allowed
                for subnet, allowed in allowed_subnets.items()
                if (subnet in caller_subnets)
            }
            if not allowed_subnets:
                reason = "Caller key is not registered in any subnet that the miner is"
                log_reffusal(ss58, reason)
                return False, json_error(403, reason)
        else:
            # accepts everything
            pass

        return (True, None)


def build_route_class(verifiers: Sequence[AbstractVerifier]) -> type[APIRoute]:
    """Build an ``APIRoute`` subclass that runs *verifiers* before handlers.

    Only paths starting with "/method" are verified; any other path is
    passed straight through to the original route handler. Verifiers run
    in order, and the first non-None response short-circuits the request.
    """

    class CheckListsRoute(APIRoute):
        def get_route_handler(self):
            original_route_handler = super().get_route_handler()

            async def custom_route_handler(
                request: Request,
            ) -> Response | JSONResponse:
                if not request.url.path.startswith("/method"):
                    unhandled_response: Response = await original_route_handler(
                        request
                    )
                    return unhandled_response
                for verifier in verifiers:
                    response = await verifier.verify(request)
                    if response is not None:
                        return response

                original_response: Response = await original_route_handler(
                    request
                )
                return original_response

            return custom_route_handler

    return CheckListsRoute
46class AbstractVerifier(Protocol): 47 @abstractmethod 48 async def verify(self, request: Request) -> JSONResponse | None: 49 """Please dont mutate the request D:""" 50 ...
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
53class StakeLimiterVerifier(AbstractVerifier): 54 def __init__( 55 self, 56 subnets_whitelist: list[int] | None, 57 params_: StakeLimiterParams | None, 58 ): 59 self.subnets_whitelist = subnets_whitelist 60 self.params_ = params_ 61 if not self.params_: 62 params = StakeLimiterParams() 63 else: 64 params = self.params_ 65 self.limiter = StakeLimiter( 66 self.subnets_whitelist, 67 epoch=params.epoch, 68 max_cache_age=params.cache_age, 69 get_refill_rate=params.get_refill_per_epoch, 70 ) 71 72 async def verify(self, request: Request): 73 if request.client is None: 74 response = JSONResponse( 75 status_code=401, 76 content={"error": "Address should be present in request"}, 77 ) 78 return response 79 80 key = request.headers.get("x-key") 81 if not key: 82 response = JSONResponse( 83 status_code=401, 84 content={"error": "Valid X-Key not provided on headers"}, 85 ) 86 return response 87 88 is_allowed = await self.limiter.allow(key) 89 90 if not is_allowed: 91 response = JSONResponse( 92 status_code=429, 93 headers={ 94 "X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds" 95 }, 96 content={"error": "Rate limit exceeded"}, 97 ) 98 return response 99 return None
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
54 def __init__( 55 self, 56 subnets_whitelist: list[int] | None, 57 params_: StakeLimiterParams | None, 58 ): 59 self.subnets_whitelist = subnets_whitelist 60 self.params_ = params_ 61 if not self.params_: 62 params = StakeLimiterParams() 63 else: 64 params = self.params_ 65 self.limiter = StakeLimiter( 66 self.subnets_whitelist, 67 epoch=params.epoch, 68 max_cache_age=params.cache_age, 69 get_refill_rate=params.get_refill_per_epoch, 70 )
72 async def verify(self, request: Request): 73 if request.client is None: 74 response = JSONResponse( 75 status_code=401, 76 content={"error": "Address should be present in request"}, 77 ) 78 return response 79 80 key = request.headers.get("x-key") 81 if not key: 82 response = JSONResponse( 83 status_code=401, 84 content={"error": "Valid X-Key not provided on headers"}, 85 ) 86 return response 87 88 is_allowed = await self.limiter.allow(key) 89 90 if not is_allowed: 91 response = JSONResponse( 92 status_code=429, 93 headers={ 94 "X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds" 95 }, 96 content={"error": "Rate limit exceeded"}, 97 ) 98 return response 99 return None
Please dont mutate the request D:
102class ListVerifier(AbstractVerifier): 103 def __init__( 104 self, 105 blacklist: list[Ss58Address] | None, 106 whitelist: list[Ss58Address] | None, 107 ip_blacklist: list[str] | None, 108 ): 109 self.blacklist = blacklist 110 self.whitelist = whitelist 111 self.ip_blacklist = ip_blacklist 112 113 async def verify(self, request: Request) -> JSONResponse | None: 114 key = request.headers.get("x-key") 115 if not key: 116 reason = "Missing header: X-Key" 117 log(f"INFO: refusing module request because: {reason}") 118 return json_error(400, "Missing header: X-Key") 119 120 ss58 = try_ss58_decode(key) 121 if ss58 is None: 122 reason = "Caller key could not be decoded into a ss58address" 123 log_reffusal(key, reason) 124 return json_error(400, reason) 125 if request.client is None: 126 return json_error(400, "Address should be present in request") 127 if self.blacklist and ss58 in self.blacklist: 128 return json_error(403, "You are blacklisted") 129 if self.ip_blacklist and request.client.host in self.ip_blacklist: 130 return json_error(403, "Your IP is blacklisted") 131 if self.whitelist and ss58 not in self.whitelist: 132 return json_error(403, "You are not whitelisted") 133 return None
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
113 async def verify(self, request: Request) -> JSONResponse | None: 114 key = request.headers.get("x-key") 115 if not key: 116 reason = "Missing header: X-Key" 117 log(f"INFO: refusing module request because: {reason}") 118 return json_error(400, "Missing header: X-Key") 119 120 ss58 = try_ss58_decode(key) 121 if ss58 is None: 122 reason = "Caller key could not be decoded into a ss58address" 123 log_reffusal(key, reason) 124 return json_error(400, reason) 125 if request.client is None: 126 return json_error(400, "Address should be present in request") 127 if self.blacklist and ss58 in self.blacklist: 128 return json_error(403, "You are blacklisted") 129 if self.ip_blacklist and request.client.host in self.ip_blacklist: 130 return json_error(403, "Your IP is blacklisted") 131 if self.whitelist and ss58 not in self.whitelist: 132 return json_error(403, "You are not whitelisted") 133 return None
Please dont mutate the request D:
136class IpLimiterVerifier(AbstractVerifier): 137 def __init__( 138 self, 139 params: IpLimiterParams | None, 140 ): 141 """ 142 :param limiter: KeyLimiter instance OR None 143 144 If limiter is None, then a default TokenBucketLimiter is used with the following config: 145 bucket_size=200, refill_rate=15 146 """ 147 148 # fallback to default limiter 149 if not params: 150 params = IpLimiterParams() 151 self._limiter = TokenBucketLimiter( 152 bucket_size=params.bucket_size, refill_rate=params.refill_rate 153 ) 154 155 async def verify(self, request: Request): 156 assert request.client is not None, "request is invalid" 157 assert request.client.host, "request is invalid." 158 159 ip = request.client.host 160 161 is_allowed = self._limiter.allow(ip) 162 163 if not is_allowed: 164 response = JSONResponse( 165 status_code=429, 166 headers={ 167 "X-RateLimit-Remaining": str(self._limiter.remaining(ip)) 168 }, 169 content={"error": "Rate limit exceeded"}, 170 ) 171 return response 172 return None
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
137 def __init__( 138 self, 139 params: IpLimiterParams | None, 140 ): 141 """ 142 :param limiter: KeyLimiter instance OR None 143 144 If limiter is None, then a default TokenBucketLimiter is used with the following config: 145 bucket_size=200, refill_rate=15 146 """ 147 148 # fallback to default limiter 149 if not params: 150 params = IpLimiterParams() 151 self._limiter = TokenBucketLimiter( 152 bucket_size=params.bucket_size, refill_rate=params.refill_rate 153 )
Parameters
- limiter: KeyLimiter instance OR None
If limiter is None, then a default TokenBucketLimiter is used with the following config: bucket_size=200, refill_rate=15
155 async def verify(self, request: Request): 156 assert request.client is not None, "request is invalid" 157 assert request.client.host, "request is invalid." 158 159 ip = request.client.host 160 161 is_allowed = self._limiter.allow(ip) 162 163 if not is_allowed: 164 response = JSONResponse( 165 status_code=429, 166 headers={ 167 "X-RateLimit-Remaining": str(self._limiter.remaining(ip)) 168 }, 169 content={"error": "Rate limit exceeded"}, 170 ) 171 return response 172 return None
Please dont mutate the request D:
175class InputHandlerVerifier(AbstractVerifier): 176 def __init__( 177 self, 178 subnets_whitelist: list[int] | None, 179 module_key: Ss58Address, 180 request_staleness: int, 181 blockchain_cache: TTLDict[str, list[Ss58Address]], 182 host_key: Keypair, 183 use_testnet: bool, 184 ): 185 self.subnets_whitelist = subnets_whitelist 186 self.module_key = module_key 187 self.request_staleness = request_staleness 188 self.blockchain_cache = blockchain_cache 189 self.host_key = host_key 190 self.use_testnet = use_testnet 191 192 async def verify(self, request: Request): 193 body = await request.body() 194 195 # TODO: we'll replace this by a Result ADT :) 196 match self._check_inputs(request, body, self.module_key): 197 case (False, error): 198 return error 199 case (True, _): 200 pass 201 202 body_dict: dict[str, dict[str, Any]] = json.loads(body) 203 timestamp = body_dict["params"].get("timestamp", None) 204 legacy_timestamp = request.headers.get("X-Timestamp", None) 205 try: 206 timestamp_to_use = ( 207 timestamp if not legacy_timestamp else legacy_timestamp 208 ) 209 request_time = datetime.fromisoformat(timestamp_to_use) 210 except Exception: 211 return JSONResponse( 212 status_code=400, 213 content={"error": "Invalid ISO timestamp given"}, 214 ) 215 if ( 216 datetime.now(timezone.utc) - request_time 217 ).total_seconds() > self.request_staleness: 218 return JSONResponse( 219 status_code=400, content={"error": "Request is too stale"} 220 ) 221 return None 222 223 def _check_inputs( 224 self, request: Request, body: bytes, module_key: Ss58Address 225 ): 226 required_headers = ["x-signature", "x-key", "x-crypto"] 227 optional_headers = ["x-timestamp"] 228 229 # TODO: we'll replace this by a Result ADT :) 230 match self._get_headers_dict( 231 request.headers, required_headers, optional_headers 232 ): 233 case (False, error): 234 return (False, error) 235 case (True, headers_dict): 236 pass 237 238 # TODO: we'll replace this by a Result ADT :) 239 match 
self._check_signature(headers_dict, body, module_key): 240 case (False, error): 241 return (False, error) 242 case (True, _): 243 pass 244 245 # TODO: we'll replace this by a Result ADT :) 246 match self._check_key_registered( 247 self.subnets_whitelist, 248 headers_dict, 249 self.blockchain_cache, 250 self.host_key, 251 self.use_testnet, 252 ): 253 case (False, error): 254 return (False, error) 255 case (True, _): 256 pass 257 258 return (True, None) 259 260 def _get_headers_dict( 261 self, 262 headers: starlette.datastructures.Headers, 263 required: list[str], 264 optional: list[str], 265 ): 266 headers_dict: dict[str, str] = {} 267 for required_header in required: 268 value = headers.get(required_header) 269 if not value: 270 code = 400 271 return False, json_error( 272 code, f"Missing header: {required_header}" 273 ) 274 headers_dict[required_header] = value 275 for optional_header in optional: 276 value = headers.get(optional_header) 277 if value: 278 headers_dict[optional_header] = value 279 280 return True, headers_dict 281 282 def _check_signature( 283 self, headers_dict: dict[str, str], body: bytes, module_key: Ss58Address 284 ): 285 key = headers_dict["x-key"] 286 signature = headers_dict["x-signature"] 287 crypto = int(headers_dict["x-crypto"]) # TODO: better handling of this 288 289 if not is_hex_string(key): 290 reason = "X-Key should be a hex value" 291 log_reffusal(key, reason) 292 return (False, json_error(400, reason)) 293 try: 294 signature = parse_hex(signature) 295 except Exception: 296 reason = "Signature sent is not a valid hex value" 297 log_reffusal(key, reason) 298 return False, json_error(400, reason) 299 try: 300 key = parse_hex(key) 301 except Exception: 302 reason = "Key sent is not a valid hex value" 303 log_reffusal(key, reason) 304 return False, json_error(400, reason) 305 # decodes the key for better logging 306 key_ss58 = try_ss58_decode(key) 307 if key_ss58 is None: 308 reason = "Caller key could not be decoded into a ss58address" 
309 log_reffusal(key.decode(), reason) 310 return (False, json_error(400, reason)) 311 312 timestamp = headers_dict.get("x-timestamp") 313 legacy_verified = False 314 if timestamp: 315 # tries to do a legacy verification 316 json_body = json.loads(body) 317 json_body["timestamp"] = timestamp 318 stamped_body = json.dumps(json_body).encode() 319 legacy_verified = signer.verify( 320 key, crypto, stamped_body, signature 321 ) 322 323 verified = signer.verify(key, crypto, body, signature) 324 if not verified and not legacy_verified: 325 reason = "Signature doesn't match" 326 log_reffusal(key_ss58, reason) 327 return (False, json_error(401, "Signatures doesn't match")) 328 329 body_dict: dict[str, dict[str, Any]] = json.loads(body) 330 target_key = body_dict["params"].get("target_key", None) 331 if not target_key or target_key != module_key: 332 reason = "Wrong target_key in body" 333 log_reffusal(key_ss58, reason) 334 return (False, json_error(401, "Wrong target_key in body")) 335 336 return (True, None) 337 338 def _check_key_registered( 339 self, 340 subnets_whitelist: list[int] | None, 341 headers_dict: dict[str, str], 342 blockchain_cache: TTLDict[str, list[Ss58Address]], 343 host_key: Keypair, 344 use_testnet: bool, 345 ): 346 key = headers_dict["x-key"] 347 if not is_hex_string(key): 348 return (False, json_error(400, "X-Key should be a hex value")) 349 key = parse_hex(key) 350 351 # TODO: checking for key being registered should be smarter 352 # e.g. query and store all registered modules periodically. 
353 354 ss58 = try_ss58_decode(key) 355 if ss58 is None: 356 reason = "Caller key could not be decoded into a ss58address" 357 log_reffusal(key.decode(), reason) 358 return (False, json_error(400, reason)) 359 360 # If subnets whitelist is specified, checks if key is registered in one 361 # of the given subnets 362 363 allowed_subnets: dict[int, bool] = {} 364 caller_subnets: list[int] = [] 365 if subnets_whitelist is not None: 366 367 def query_keys(subnet: int): 368 try: 369 node_url = get_node_url(None, use_testnet=use_testnet) 370 client = make_client( 371 node_url 372 ) # TODO: get client from outer context 373 return [*client.query_map_key(subnet).values()] 374 except Exception: 375 log("WARNING: Could not connect to a blockchain node") 376 return_list: list[Ss58Address] = [] 377 return return_list 378 379 # TODO: client pool for entire module server 380 381 got_keys = False 382 no_keys_reason = ( 383 "Miner could not connect to a blockchain node " 384 "or there is no key registered on the subnet(s) {} " 385 ) 386 for subnet in subnets_whitelist: 387 get_keys_on_subnet = partial(query_keys, subnet) 388 cache_key = f"keys_on_subnet_{subnet}" 389 keys_on_subnet = blockchain_cache.get_or_insert_lazy( 390 cache_key, get_keys_on_subnet 391 ) 392 if len(keys_on_subnet) == 0: 393 reason = no_keys_reason.format(subnet) 394 log(f"WARNING: {reason}") 395 else: 396 got_keys = True 397 if host_key.ss58_address not in keys_on_subnet: 398 log( 399 f"WARNING: This miner is deregistered on subnet {subnet}" 400 ) 401 else: 402 allowed_subnets[subnet] = True 403 if ss58 in keys_on_subnet: 404 caller_subnets.append(subnet) 405 if not got_keys: 406 return False, json_error( 407 503, no_keys_reason.format(subnets_whitelist) 408 ) 409 if not allowed_subnets: 410 log("WARNING: Miner is not registered on any subnet") 411 return False, json_error( 412 403, "Miner is not registered on any subnet" 413 ) 414 415 # searches for a common subnet between caller and miner 416 # TODO: use 
sets 417 allowed_subnets = { 418 subnet: allowed 419 for subnet, allowed in allowed_subnets.items() 420 if (subnet in caller_subnets) 421 } 422 if not allowed_subnets: 423 reason = "Caller key is not registered in any subnet that the miner is" 424 log_reffusal(ss58, reason) 425 return False, json_error(403, reason) 426 else: 427 # accepts everything 428 pass 429 430 return (True, None)
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
176 def __init__( 177 self, 178 subnets_whitelist: list[int] | None, 179 module_key: Ss58Address, 180 request_staleness: int, 181 blockchain_cache: TTLDict[str, list[Ss58Address]], 182 host_key: Keypair, 183 use_testnet: bool, 184 ): 185 self.subnets_whitelist = subnets_whitelist 186 self.module_key = module_key 187 self.request_staleness = request_staleness 188 self.blockchain_cache = blockchain_cache 189 self.host_key = host_key 190 self.use_testnet = use_testnet
192 async def verify(self, request: Request): 193 body = await request.body() 194 195 # TODO: we'll replace this by a Result ADT :) 196 match self._check_inputs(request, body, self.module_key): 197 case (False, error): 198 return error 199 case (True, _): 200 pass 201 202 body_dict: dict[str, dict[str, Any]] = json.loads(body) 203 timestamp = body_dict["params"].get("timestamp", None) 204 legacy_timestamp = request.headers.get("X-Timestamp", None) 205 try: 206 timestamp_to_use = ( 207 timestamp if not legacy_timestamp else legacy_timestamp 208 ) 209 request_time = datetime.fromisoformat(timestamp_to_use) 210 except Exception: 211 return JSONResponse( 212 status_code=400, 213 content={"error": "Invalid ISO timestamp given"}, 214 ) 215 if ( 216 datetime.now(timezone.utc) - request_time 217 ).total_seconds() > self.request_staleness: 218 return JSONResponse( 219 status_code=400, content={"error": "Request is too stale"} 220 ) 221 return None
Please dont mutate the request D:
433def build_route_class(verifiers: Sequence[AbstractVerifier]) -> type[APIRoute]: 434 class CheckListsRoute(APIRoute): 435 def get_route_handler(self): 436 original_route_handler = super().get_route_handler() 437 438 async def custom_route_handler( 439 request: Request, 440 ) -> Response | JSONResponse: 441 if not request.url.path.startswith("/method"): 442 unhandled_response: Response = await original_route_handler( 443 request 444 ) 445 return unhandled_response 446 for verifier in verifiers: 447 response = await verifier.verify(request) 448 if response is not None: 449 return response 450 451 original_response: Response = await original_route_handler( 452 request 453 ) 454 return original_response 455 456 return custom_route_handler 457 458 return CheckListsRoute