Edit on GitHub

communex.module.routers.module_routers

  1import json
  2import re
  3from abc import abstractmethod
  4from datetime import datetime, timezone
  5from functools import partial
  6from typing import Any, Protocol, Sequence
  7
  8import starlette.datastructures
  9from fastapi import Request, Response
 10from fastapi.responses import JSONResponse
 11from fastapi.routing import APIRoute
 12from keylimiter import TokenBucketLimiter
 13from substrateinterface import Keypair
 14
 15from communex._common import get_node_url
 16from communex.module import _signer as signer
 17from communex.module._rate_limiters._stake_limiter import StakeLimiter
 18from communex.module._rate_limiters.limiters import (
 19    IpLimiterParams,
 20    StakeLimiterParams,
 21)
 22from communex.module._util import (
 23    json_error,
 24    log,
 25    log_reffusal,
 26    make_client,
 27    try_ss58_decode,
 28)
 29from communex.types import Ss58Address
 30from communex.util.memo import TTLDict
 31
 32HEX_PATTERN = re.compile(r"^[0-9a-fA-F]+$")
 33
 34
 35def is_hex_string(string: str):
 36    return bool(HEX_PATTERN.match(string))
 37
 38
 39def parse_hex(hex_str: str) -> bytes:
 40    if hex_str[0:2] == "0x":
 41        return bytes.fromhex(hex_str[2:])
 42    else:
 43        return bytes.fromhex(hex_str)
 44
 45
 46class AbstractVerifier(Protocol):
 47    @abstractmethod
 48    async def verify(self, request: Request) -> JSONResponse | None:
 49        """Please dont mutate the request D:"""
 50        ...
 51
 52
 53class StakeLimiterVerifier(AbstractVerifier):
 54    def __init__(
 55        self,
 56        subnets_whitelist: list[int] | None,
 57        params_: StakeLimiterParams | None,
 58    ):
 59        self.subnets_whitelist = subnets_whitelist
 60        self.params_ = params_
 61        if not self.params_:
 62            params = StakeLimiterParams()
 63        else:
 64            params = self.params_
 65        self.limiter = StakeLimiter(
 66            self.subnets_whitelist,
 67            epoch=params.epoch,
 68            max_cache_age=params.cache_age,
 69            get_refill_rate=params.get_refill_per_epoch,
 70        )
 71
 72    async def verify(self, request: Request):
 73        if request.client is None:
 74            response = JSONResponse(
 75                status_code=401,
 76                content={"error": "Address should be present in request"},
 77            )
 78            return response
 79
 80        key = request.headers.get("x-key")
 81        if not key:
 82            response = JSONResponse(
 83                status_code=401,
 84                content={"error": "Valid X-Key not provided on headers"},
 85            )
 86            return response
 87
 88        is_allowed = await self.limiter.allow(key)
 89
 90        if not is_allowed:
 91            response = JSONResponse(
 92                status_code=429,
 93                headers={
 94                    "X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds"
 95                },
 96                content={"error": "Rate limit exceeded"},
 97            )
 98            return response
 99        return None
100
101
class ListVerifier(AbstractVerifier):
    """Refuses callers by key blacklist/whitelist and by IP blacklist.

    A list that is None or empty disables the corresponding check.
    """

    def __init__(
        self,
        blacklist: list[Ss58Address] | None,
        whitelist: list[Ss58Address] | None,
        ip_blacklist: list[str] | None,
    ):
        self.blacklist, self.whitelist, self.ip_blacklist = (
            blacklist,
            whitelist,
            ip_blacklist,
        )

    async def verify(self, request: Request) -> JSONResponse | None:
        """Return a 400/403 error response for refused callers, else None."""
        caller_key = request.headers.get("x-key")
        if not caller_key:
            why = "Missing header: X-Key"
            log(f"INFO: refusing module request because: {why}")
            return json_error(400, why)

        ss58 = try_ss58_decode(caller_key)
        if ss58 is None:
            why = "Caller key could not be decoded into a ss58address"
            log_reffusal(caller_key, why)
            return json_error(400, why)

        client = request.client
        if client is None:
            return json_error(400, "Address should be present in request")
        if self.blacklist and ss58 in self.blacklist:
            return json_error(403, "You are blacklisted")
        if self.ip_blacklist and client.host in self.ip_blacklist:
            return json_error(403, "Your IP is blacklisted")
        if self.whitelist and ss58 not in self.whitelist:
            return json_error(403, "You are not whitelisted")
        return None
134
135
class IpLimiterVerifier(AbstractVerifier):
    """Rate-limits requests per client IP address with a TokenBucketLimiter."""

    def __init__(
        self,
        params: IpLimiterParams | None,
    ):
        """
        :param params: IpLimiterParams instance OR None

        If params is None (or otherwise falsy), a default ``IpLimiterParams()``
        supplies ``bucket_size`` and ``refill_rate`` for the TokenBucketLimiter.
        NOTE(review): earlier docs listed the defaults as bucket_size=200,
        refill_rate=15 — confirm against IpLimiterParams.
        """
        # fallback to default limiter parameters
        if not params:
            params = IpLimiterParams()
        self._limiter = TokenBucketLimiter(
            bucket_size=params.bucket_size, refill_rate=params.refill_rate
        )

    async def verify(self, request: Request):
        """Return 429 once the client IP exhausts its bucket, else None.

        Requests without a client address are refused with 400.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O`` and otherwise surface as an unhandled 500 error.
        client = request.client
        if client is None or not client.host:
            return JSONResponse(
                status_code=400,
                content={"error": "Address should be present in request"},
            )

        ip = client.host
        if self._limiter.allow(ip):
            return None
        return JSONResponse(
            status_code=429,
            headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
            content={"error": "Rate limit exceeded"},
        )
173
174
175class InputHandlerVerifier(AbstractVerifier):
176    def __init__(
177        self,
178        subnets_whitelist: list[int] | None,
179        module_key: Ss58Address,
180        request_staleness: int,
181        blockchain_cache: TTLDict[str, list[Ss58Address]],
182        host_key: Keypair,
183        use_testnet: bool,
184    ):
185        self.subnets_whitelist = subnets_whitelist
186        self.module_key = module_key
187        self.request_staleness = request_staleness
188        self.blockchain_cache = blockchain_cache
189        self.host_key = host_key
190        self.use_testnet = use_testnet
191
192    async def verify(self, request: Request):
193        body = await request.body()
194
195        # TODO: we'll replace this by a Result ADT :)
196        match self._check_inputs(request, body, self.module_key):
197            case (False, error):
198                return error
199            case (True, _):
200                pass
201
202        body_dict: dict[str, dict[str, Any]] = json.loads(body)
203        timestamp = body_dict["params"].get("timestamp", None)
204        legacy_timestamp = request.headers.get("X-Timestamp", None)
205        try:
206            timestamp_to_use = (
207                timestamp if not legacy_timestamp else legacy_timestamp
208            )
209            request_time = datetime.fromisoformat(timestamp_to_use)
210        except Exception:
211            return JSONResponse(
212                status_code=400,
213                content={"error": "Invalid ISO timestamp given"},
214            )
215        if (
216            datetime.now(timezone.utc) - request_time
217        ).total_seconds() > self.request_staleness:
218            return JSONResponse(
219                status_code=400, content={"error": "Request is too stale"}
220            )
221        return None
222
223    def _check_inputs(
224        self, request: Request, body: bytes, module_key: Ss58Address
225    ):
226        required_headers = ["x-signature", "x-key", "x-crypto"]
227        optional_headers = ["x-timestamp"]
228
229        # TODO: we'll replace this by a Result ADT :)
230        match self._get_headers_dict(
231            request.headers, required_headers, optional_headers
232        ):
233            case (False, error):
234                return (False, error)
235            case (True, headers_dict):
236                pass
237
238        # TODO: we'll replace this by a Result ADT :)
239        match self._check_signature(headers_dict, body, module_key):
240            case (False, error):
241                return (False, error)
242            case (True, _):
243                pass
244
245        # TODO: we'll replace this by a Result ADT :)
246        match self._check_key_registered(
247            self.subnets_whitelist,
248            headers_dict,
249            self.blockchain_cache,
250            self.host_key,
251            self.use_testnet,
252        ):
253            case (False, error):
254                return (False, error)
255            case (True, _):
256                pass
257
258        return (True, None)
259
260    def _get_headers_dict(
261        self,
262        headers: starlette.datastructures.Headers,
263        required: list[str],
264        optional: list[str],
265    ):
266        headers_dict: dict[str, str] = {}
267        for required_header in required:
268            value = headers.get(required_header)
269            if not value:
270                code = 400
271                return False, json_error(
272                    code, f"Missing header: {required_header}"
273                )
274            headers_dict[required_header] = value
275        for optional_header in optional:
276            value = headers.get(optional_header)
277            if value:
278                headers_dict[optional_header] = value
279
280        return True, headers_dict
281
282    def _check_signature(
283        self, headers_dict: dict[str, str], body: bytes, module_key: Ss58Address
284    ):
285        key = headers_dict["x-key"]
286        signature = headers_dict["x-signature"]
287        crypto = int(headers_dict["x-crypto"])  # TODO: better handling of this
288
289        if not is_hex_string(key):
290            reason = "X-Key should be a hex value"
291            log_reffusal(key, reason)
292            return (False, json_error(400, reason))
293        try:
294            signature = parse_hex(signature)
295        except Exception:
296            reason = "Signature sent is not a valid hex value"
297            log_reffusal(key, reason)
298            return False, json_error(400, reason)
299        try:
300            key = parse_hex(key)
301        except Exception:
302            reason = "Key sent is not a valid hex value"
303            log_reffusal(key, reason)
304            return False, json_error(400, reason)
305        # decodes the key for better logging
306        key_ss58 = try_ss58_decode(key)
307        if key_ss58 is None:
308            reason = "Caller key could not be decoded into a ss58address"
309            log_reffusal(key.decode(), reason)
310            return (False, json_error(400, reason))
311
312        timestamp = headers_dict.get("x-timestamp")
313        legacy_verified = False
314        if timestamp:
315            # tries to do a legacy verification
316            json_body = json.loads(body)
317            json_body["timestamp"] = timestamp
318            stamped_body = json.dumps(json_body).encode()
319            legacy_verified = signer.verify(
320                key, crypto, stamped_body, signature
321            )
322
323        verified = signer.verify(key, crypto, body, signature)
324        if not verified and not legacy_verified:
325            reason = "Signature doesn't match"
326            log_reffusal(key_ss58, reason)
327            return (False, json_error(401, "Signatures doesn't match"))
328
329        body_dict: dict[str, dict[str, Any]] = json.loads(body)
330        target_key = body_dict["params"].get("target_key", None)
331        if not target_key or target_key != module_key:
332            reason = "Wrong target_key in body"
333            log_reffusal(key_ss58, reason)
334            return (False, json_error(401, "Wrong target_key in body"))
335
336        return (True, None)
337
338    def _check_key_registered(
339        self,
340        subnets_whitelist: list[int] | None,
341        headers_dict: dict[str, str],
342        blockchain_cache: TTLDict[str, list[Ss58Address]],
343        host_key: Keypair,
344        use_testnet: bool,
345    ):
346        key = headers_dict["x-key"]
347        if not is_hex_string(key):
348            return (False, json_error(400, "X-Key should be a hex value"))
349        key = parse_hex(key)
350
351        # TODO: checking for key being registered should be smarter
352        # e.g. query and store all registered modules periodically.
353
354        ss58 = try_ss58_decode(key)
355        if ss58 is None:
356            reason = "Caller key could not be decoded into a ss58address"
357            log_reffusal(key.decode(), reason)
358            return (False, json_error(400, reason))
359
360        # If subnets whitelist is specified, checks if key is registered in one
361        # of the given subnets
362
363        allowed_subnets: dict[int, bool] = {}
364        caller_subnets: list[int] = []
365        if subnets_whitelist is not None:
366
367            def query_keys(subnet: int):
368                try:
369                    node_url = get_node_url(None, use_testnet=use_testnet)
370                    client = make_client(
371                        node_url
372                    )  # TODO: get client from outer context
373                    return [*client.query_map_key(subnet).values()]
374                except Exception:
375                    log("WARNING: Could not connect to a blockchain node")
376                    return_list: list[Ss58Address] = []
377                    return return_list
378
379            # TODO: client pool for entire module server
380
381            got_keys = False
382            no_keys_reason = (
383                "Miner could not connect to a blockchain node "
384                "or there is no key registered on the subnet(s) {} "
385            )
386            for subnet in subnets_whitelist:
387                get_keys_on_subnet = partial(query_keys, subnet)
388                cache_key = f"keys_on_subnet_{subnet}"
389                keys_on_subnet = blockchain_cache.get_or_insert_lazy(
390                    cache_key, get_keys_on_subnet
391                )
392                if len(keys_on_subnet) == 0:
393                    reason = no_keys_reason.format(subnet)
394                    log(f"WARNING: {reason}")
395                else:
396                    got_keys = True
397                if host_key.ss58_address not in keys_on_subnet:
398                    log(
399                        f"WARNING: This miner is deregistered on subnet {subnet}"
400                    )
401                else:
402                    allowed_subnets[subnet] = True
403                if ss58 in keys_on_subnet:
404                    caller_subnets.append(subnet)
405            if not got_keys:
406                return False, json_error(
407                    503, no_keys_reason.format(subnets_whitelist)
408                )
409            if not allowed_subnets:
410                log("WARNING: Miner is not registered on any subnet")
411                return False, json_error(
412                    403, "Miner is not registered on any subnet"
413                )
414
415            # searches for a common subnet between caller and miner
416            # TODO: use sets
417            allowed_subnets = {
418                subnet: allowed
419                for subnet, allowed in allowed_subnets.items()
420                if (subnet in caller_subnets)
421            }
422            if not allowed_subnets:
423                reason = "Caller key is not registered in any subnet that the miner is"
424                log_reffusal(ss58, reason)
425                return False, json_error(403, reason)
426        else:
427            # accepts everything
428            pass
429
430        return (True, None)
431
432
def build_route_class(verifiers: Sequence[AbstractVerifier]) -> type[APIRoute]:
    """Build an APIRoute subclass that runs ``verifiers`` before ``/method`` routes.

    Paths not starting with ``/method`` go straight to the original handler.
    Otherwise each verifier is awaited in order; the first non-None response
    is returned and the original handler is not called.
    """

    class CheckListsRoute(APIRoute):
        def get_route_handler(self):
            handler = super().get_route_handler()

            async def custom_route_handler(
                request: Request,
            ) -> Response | JSONResponse:
                if request.url.path.startswith("/method"):
                    for verifier in verifiers:
                        refusal = await verifier.verify(request)
                        if refusal is not None:
                            return refusal
                response: Response = await handler(request)
                return response

            return custom_route_handler

    return CheckListsRoute
HEX_PATTERN = re.compile(r"^[0-9a-fA-F]+$")


def is_hex_string(string: str) -> bool:
    """Return True if ``string`` consists only of hex digits (at least one).

    A ``0x`` prefix is NOT accepted. ``fullmatch`` is used because ``$`` also
    matches right before a trailing newline, so ``HEX_PATTERN.match`` would
    wrongly accept strings such as ``"ab\\n"``.
    """
    return HEX_PATTERN.fullmatch(string) is not None


def parse_hex(hex_str: str) -> bytes:
    """Decode a hex string, optionally prefixed with ``0x``, into bytes.

    Raises ValueError (from ``bytes.fromhex``) when the remainder is not
    valid hex.
    """
    return bytes.fromhex(hex_str.removeprefix("0x"))
class AbstractVerifier(Protocol):
    """Structural interface for request verifiers.

    ``verify`` returns a response to refuse the request, or ``None`` to let
    it continue.
    """

    @abstractmethod
    async def verify(self, request: Request) -> JSONResponse | None:
        """Inspect ``request``; return a refusal response or None.

        Please don't mutate the request.
        """
        ...

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
@abstractmethod
async def verify(self, request: Request) -> JSONResponse | None:
    """Inspect ``request``; return a refusal response or None to accept.

    Please don't mutate the request.
    """
    ...

Please don't mutate the request.

class StakeLimiterVerifier(AbstractVerifier):
    """Rate-limits callers by their ``X-Key`` header through a StakeLimiter.

    The limiter is configured from ``params_`` or, when that is falsy, from a
    default ``StakeLimiterParams()``.
    """

    def __init__(
        self,
        subnets_whitelist: list[int] | None,
        params_: StakeLimiterParams | None,
    ):
        self.subnets_whitelist = subnets_whitelist
        self.params_ = params_
        # Fall back to default parameters when none were supplied.
        effective = self.params_ if self.params_ else StakeLimiterParams()
        self.limiter = StakeLimiter(
            self.subnets_whitelist,
            epoch=effective.epoch,
            max_cache_age=effective.cache_age,
            get_refill_rate=effective.get_refill_per_epoch,
        )

    async def verify(self, request: Request):
        """Refuse with 401 on a missing client/X-Key, 429 when rate limited."""
        if request.client is None:
            return JSONResponse(
                status_code=401,
                content={"error": "Address should be present in request"},
            )

        caller_key = request.headers.get("x-key")
        if not caller_key:
            return JSONResponse(
                status_code=401,
                content={"error": "Valid X-Key not provided on headers"},
            )

        if await self.limiter.allow(caller_key):
            return None

        wait = await self.limiter.retry_after(caller_key)
        return JSONResponse(
            status_code=429,
            headers={"X-RateLimit-TryAfter": f"{wait} seconds"},
            content={"error": "Rate limit exceeded"},
        )

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
def __init__(
    self,
    subnets_whitelist: list[int] | None,
    params_: StakeLimiterParams | None,
):
    """Store the settings and build the StakeLimiter (defaults if params_ is falsy)."""
    self.subnets_whitelist = subnets_whitelist
    self.params_ = params_
    effective = self.params_ if self.params_ else StakeLimiterParams()
    self.limiter = StakeLimiter(
        self.subnets_whitelist,
        epoch=effective.epoch,
        max_cache_age=effective.cache_age,
        get_refill_rate=effective.get_refill_per_epoch,
    )
subnets_whitelist
params_
limiter
async def verify(self, request: Request):
    """Refuse with 401 on a missing client/X-Key, 429 when rate limited."""
    if request.client is None:
        return JSONResponse(
            status_code=401,
            content={"error": "Address should be present in request"},
        )

    caller_key = request.headers.get("x-key")
    if not caller_key:
        return JSONResponse(
            status_code=401,
            content={"error": "Valid X-Key not provided on headers"},
        )

    if await self.limiter.allow(caller_key):
        return None

    wait = await self.limiter.retry_after(caller_key)
    return JSONResponse(
        status_code=429,
        headers={"X-RateLimit-TryAfter": f"{wait} seconds"},
        content={"error": "Rate limit exceeded"},
    )

Please don't mutate the request.

class ListVerifier(AbstractVerifier):
    """Refuses callers by key blacklist/whitelist and by IP blacklist.

    A list that is None or empty disables the corresponding check.
    """

    def __init__(
        self,
        blacklist: list[Ss58Address] | None,
        whitelist: list[Ss58Address] | None,
        ip_blacklist: list[str] | None,
    ):
        self.blacklist, self.whitelist, self.ip_blacklist = (
            blacklist,
            whitelist,
            ip_blacklist,
        )

    async def verify(self, request: Request) -> JSONResponse | None:
        """Return a 400/403 error response for refused callers, else None."""
        caller_key = request.headers.get("x-key")
        if not caller_key:
            why = "Missing header: X-Key"
            log(f"INFO: refusing module request because: {why}")
            return json_error(400, why)

        ss58 = try_ss58_decode(caller_key)
        if ss58 is None:
            why = "Caller key could not be decoded into a ss58address"
            log_reffusal(caller_key, why)
            return json_error(400, why)

        client = request.client
        if client is None:
            return json_error(400, "Address should be present in request")
        if self.blacklist and ss58 in self.blacklist:
            return json_error(403, "You are blacklisted")
        if self.ip_blacklist and client.host in self.ip_blacklist:
            return json_error(403, "Your IP is blacklisted")
        if self.whitelist and ss58 not in self.whitelist:
            return json_error(403, "You are not whitelisted")
        return None

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
def __init__(
    self,
    blacklist: list[Ss58Address] | None,
    whitelist: list[Ss58Address] | None,
    ip_blacklist: list[str] | None,
):
    """Store the key blacklist, key whitelist and IP blacklist."""
    self.blacklist, self.whitelist, self.ip_blacklist = (
        blacklist,
        whitelist,
        ip_blacklist,
    )
blacklist
whitelist
ip_blacklist
async def verify(self, request: Request) -> JSONResponse | None:
    """Return a 400/403 error response for refused callers, else None."""
    caller_key = request.headers.get("x-key")
    if not caller_key:
        why = "Missing header: X-Key"
        log(f"INFO: refusing module request because: {why}")
        return json_error(400, why)

    ss58 = try_ss58_decode(caller_key)
    if ss58 is None:
        why = "Caller key could not be decoded into a ss58address"
        log_reffusal(caller_key, why)
        return json_error(400, why)

    client = request.client
    if client is None:
        return json_error(400, "Address should be present in request")
    if self.blacklist and ss58 in self.blacklist:
        return json_error(403, "You are blacklisted")
    if self.ip_blacklist and client.host in self.ip_blacklist:
        return json_error(403, "Your IP is blacklisted")
    if self.whitelist and ss58 not in self.whitelist:
        return json_error(403, "You are not whitelisted")
    return None

Please don't mutate the request.

class IpLimiterVerifier(AbstractVerifier):
    """Rate-limits requests per client IP address with a TokenBucketLimiter."""

    def __init__(
        self,
        params: IpLimiterParams | None,
    ):
        """
        :param params: IpLimiterParams instance OR None

        If params is None (or otherwise falsy), a default ``IpLimiterParams()``
        supplies ``bucket_size`` and ``refill_rate`` for the TokenBucketLimiter.
        NOTE(review): earlier docs listed the defaults as bucket_size=200,
        refill_rate=15 — confirm against IpLimiterParams.
        """
        # fallback to default limiter parameters
        if not params:
            params = IpLimiterParams()
        self._limiter = TokenBucketLimiter(
            bucket_size=params.bucket_size, refill_rate=params.refill_rate
        )

    async def verify(self, request: Request):
        """Return 429 once the client IP exhausts its bucket, else None.

        Requests without a client address are refused with 400.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O`` and otherwise surface as an unhandled 500 error.
        client = request.client
        if client is None or not client.host:
            return JSONResponse(
                status_code=400,
                content={"error": "Address should be present in request"},
            )

        ip = client.host
        if self._limiter.allow(ip):
            return None
        return JSONResponse(
            status_code=429,
            headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
            content={"error": "Rate limit exceeded"},
        )

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
def __init__(
    self,
    params: IpLimiterParams | None,
):
    """
    :param params: IpLimiterParams instance OR None

    If params is None (or otherwise falsy), a default ``IpLimiterParams()``
    supplies ``bucket_size`` and ``refill_rate`` for the TokenBucketLimiter.
    NOTE(review): earlier docs listed the defaults as bucket_size=200,
    refill_rate=15 — confirm against IpLimiterParams.
    """
    # fallback to default limiter parameters
    if not params:
        params = IpLimiterParams()
    self._limiter = TokenBucketLimiter(
        bucket_size=params.bucket_size, refill_rate=params.refill_rate
    )
Parameters
  • params: IpLimiterParams instance OR None

If params is None, a TokenBucketLimiter is built from a default IpLimiterParams() (previously documented as bucket_size=200, refill_rate=15 — confirm against IpLimiterParams).

async def verify(self, request: Request):
    """Return 429 once the client IP exhausts its bucket, else None.

    Requests without a client address are refused with 400.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O`` and otherwise surface as an unhandled 500 error.
    client = request.client
    if client is None or not client.host:
        return JSONResponse(
            status_code=400,
            content={"error": "Address should be present in request"},
        )

    ip = client.host
    if self._limiter.allow(ip):
        return None
    return JSONResponse(
        status_code=429,
        headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
        content={"error": "Rate limit exceeded"},
    )

Please don't mutate the request.

class InputHandlerVerifier(AbstractVerifier):
176class InputHandlerVerifier(AbstractVerifier):
177    def __init__(
178        self,
179        subnets_whitelist: list[int] | None,
180        module_key: Ss58Address,
181        request_staleness: int,
182        blockchain_cache: TTLDict[str, list[Ss58Address]],
183        host_key: Keypair,
184        use_testnet: bool,
185    ):
186        self.subnets_whitelist = subnets_whitelist
187        self.module_key = module_key
188        self.request_staleness = request_staleness
189        self.blockchain_cache = blockchain_cache
190        self.host_key = host_key
191        self.use_testnet = use_testnet
192
193    async def verify(self, request: Request):
194        body = await request.body()
195
196        # TODO: we'll replace this by a Result ADT :)
197        match self._check_inputs(request, body, self.module_key):
198            case (False, error):
199                return error
200            case (True, _):
201                pass
202
203        body_dict: dict[str, dict[str, Any]] = json.loads(body)
204        timestamp = body_dict["params"].get("timestamp", None)
205        legacy_timestamp = request.headers.get("X-Timestamp", None)
206        try:
207            timestamp_to_use = (
208                timestamp if not legacy_timestamp else legacy_timestamp
209            )
210            request_time = datetime.fromisoformat(timestamp_to_use)
211        except Exception:
212            return JSONResponse(
213                status_code=400,
214                content={"error": "Invalid ISO timestamp given"},
215            )
216        if (
217            datetime.now(timezone.utc) - request_time
218        ).total_seconds() > self.request_staleness:
219            return JSONResponse(
220                status_code=400, content={"error": "Request is too stale"}
221            )
222        return None
223
224    def _check_inputs(
225        self, request: Request, body: bytes, module_key: Ss58Address
226    ):
227        required_headers = ["x-signature", "x-key", "x-crypto"]
228        optional_headers = ["x-timestamp"]
229
230        # TODO: we'll replace this by a Result ADT :)
231        match self._get_headers_dict(
232            request.headers, required_headers, optional_headers
233        ):
234            case (False, error):
235                return (False, error)
236            case (True, headers_dict):
237                pass
238
239        # TODO: we'll replace this by a Result ADT :)
240        match self._check_signature(headers_dict, body, module_key):
241            case (False, error):
242                return (False, error)
243            case (True, _):
244                pass
245
246        # TODO: we'll replace this by a Result ADT :)
247        match self._check_key_registered(
248            self.subnets_whitelist,
249            headers_dict,
250            self.blockchain_cache,
251            self.host_key,
252            self.use_testnet,
253        ):
254            case (False, error):
255                return (False, error)
256            case (True, _):
257                pass
258
259        return (True, None)
260
261    def _get_headers_dict(
262        self,
263        headers: starlette.datastructures.Headers,
264        required: list[str],
265        optional: list[str],
266    ):
267        headers_dict: dict[str, str] = {}
268        for required_header in required:
269            value = headers.get(required_header)
270            if not value:
271                code = 400
272                return False, json_error(
273                    code, f"Missing header: {required_header}"
274                )
275            headers_dict[required_header] = value
276        for optional_header in optional:
277            value = headers.get(optional_header)
278            if value:
279                headers_dict[optional_header] = value
280
281        return True, headers_dict
282
283    def _check_signature(
284        self, headers_dict: dict[str, str], body: bytes, module_key: Ss58Address
285    ):
286        key = headers_dict["x-key"]
287        signature = headers_dict["x-signature"]
288        crypto = int(headers_dict["x-crypto"])  # TODO: better handling of this
289
290        if not is_hex_string(key):
291            reason = "X-Key should be a hex value"
292            log_reffusal(key, reason)
293            return (False, json_error(400, reason))
294        try:
295            signature = parse_hex(signature)
296        except Exception:
297            reason = "Signature sent is not a valid hex value"
298            log_reffusal(key, reason)
299            return False, json_error(400, reason)
300        try:
301            key = parse_hex(key)
302        except Exception:
303            reason = "Key sent is not a valid hex value"
304            log_reffusal(key, reason)
305            return False, json_error(400, reason)
306        # decodes the key for better logging
307        key_ss58 = try_ss58_decode(key)
308        if key_ss58 is None:
309            reason = "Caller key could not be decoded into a ss58address"
310            log_reffusal(key.decode(), reason)
311            return (False, json_error(400, reason))
312
313        timestamp = headers_dict.get("x-timestamp")
314        legacy_verified = False
315        if timestamp:
316            # tries to do a legacy verification
317            json_body = json.loads(body)
318            json_body["timestamp"] = timestamp
319            stamped_body = json.dumps(json_body).encode()
320            legacy_verified = signer.verify(
321                key, crypto, stamped_body, signature
322            )
323
324        verified = signer.verify(key, crypto, body, signature)
325        if not verified and not legacy_verified:
326            reason = "Signature doesn't match"
327            log_reffusal(key_ss58, reason)
328            return (False, json_error(401, "Signatures doesn't match"))
329
330        body_dict: dict[str, dict[str, Any]] = json.loads(body)
331        target_key = body_dict["params"].get("target_key", None)
332        if not target_key or target_key != module_key:
333            reason = "Wrong target_key in body"
334            log_reffusal(key_ss58, reason)
335            return (False, json_error(401, "Wrong target_key in body"))
336
337        return (True, None)
338
339    def _check_key_registered(
340        self,
341        subnets_whitelist: list[int] | None,
342        headers_dict: dict[str, str],
343        blockchain_cache: TTLDict[str, list[Ss58Address]],
344        host_key: Keypair,
345        use_testnet: bool,
346    ):
347        key = headers_dict["x-key"]
348        if not is_hex_string(key):
349            return (False, json_error(400, "X-Key should be a hex value"))
350        key = parse_hex(key)
351
352        # TODO: checking for key being registered should be smarter
353        # e.g. query and store all registered modules periodically.
354
355        ss58 = try_ss58_decode(key)
356        if ss58 is None:
357            reason = "Caller key could not be decoded into a ss58address"
358            log_reffusal(key.decode(), reason)
359            return (False, json_error(400, reason))
360
361        # If subnets whitelist is specified, checks if key is registered in one
362        # of the given subnets
363
364        allowed_subnets: dict[int, bool] = {}
365        caller_subnets: list[int] = []
366        if subnets_whitelist is not None:
367
368            def query_keys(subnet: int):
369                try:
370                    node_url = get_node_url(None, use_testnet=use_testnet)
371                    client = make_client(
372                        node_url
373                    )  # TODO: get client from outer context
374                    return [*client.query_map_key(subnet).values()]
375                except Exception:
376                    log("WARNING: Could not connect to a blockchain node")
377                    return_list: list[Ss58Address] = []
378                    return return_list
379
380            # TODO: client pool for entire module server
381
382            got_keys = False
383            no_keys_reason = (
384                "Miner could not connect to a blockchain node "
385                "or there is no key registered on the subnet(s) {} "
386            )
387            for subnet in subnets_whitelist:
388                get_keys_on_subnet = partial(query_keys, subnet)
389                cache_key = f"keys_on_subnet_{subnet}"
390                keys_on_subnet = blockchain_cache.get_or_insert_lazy(
391                    cache_key, get_keys_on_subnet
392                )
393                if len(keys_on_subnet) == 0:
394                    reason = no_keys_reason.format(subnet)
395                    log(f"WARNING: {reason}")
396                else:
397                    got_keys = True
398                if host_key.ss58_address not in keys_on_subnet:
399                    log(
400                        f"WARNING: This miner is deregistered on subnet {subnet}"
401                    )
402                else:
403                    allowed_subnets[subnet] = True
404                if ss58 in keys_on_subnet:
405                    caller_subnets.append(subnet)
406            if not got_keys:
407                return False, json_error(
408                    503, no_keys_reason.format(subnets_whitelist)
409                )
410            if not allowed_subnets:
411                log("WARNING: Miner is not registered on any subnet")
412                return False, json_error(
413                    403, "Miner is not registered on any subnet"
414                )
415
416            # searches for a common subnet between caller and miner
417            # TODO: use sets
418            allowed_subnets = {
419                subnet: allowed
420                for subnet, allowed in allowed_subnets.items()
421                if (subnet in caller_subnets)
422            }
423            if not allowed_subnets:
424                reason = "Caller key is not registered in any subnet that the miner is"
425                log_reffusal(ss58, reason)
426                return False, json_error(403, reason)
427        else:
428            # accepts everything
429            pass
430
431        return (True, None)

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
InputHandlerVerifier( subnets_whitelist: list[int] | None, module_key: communex.types.Ss58Address, request_staleness: int, blockchain_cache: communex.util.memo.TTLDict[str, list[communex.types.Ss58Address]], host_key: substrateinterface.keypair.Keypair, use_testnet: bool)
def __init__(
    self,
    subnets_whitelist: list[int] | None,
    module_key: Ss58Address,
    request_staleness: int,
    blockchain_cache: TTLDict[str, list[Ss58Address]],
    host_key: Keypair,
    use_testnet: bool,
):
    """Store the verifier's configuration on the instance.

    :param subnets_whitelist: subnets to check registration on, or None
    :param module_key: address that the body's ``target_key`` must equal
    :param request_staleness: maximum accepted request age, in seconds
    :param blockchain_cache: TTL cache of registered keys per subnet
    :param host_key: keypair of the module serving the request
    :param use_testnet: whether registered keys are queried on testnet
    """
    # identities: the module being called and the host serving it
    self.module_key = module_key
    self.host_key = host_key
    # settings used by the subnet-registration check
    self.subnets_whitelist = subnets_whitelist
    self.blockchain_cache = blockchain_cache
    self.use_testnet = use_testnet
    # freshness window compared against the request's age
    self.request_staleness = request_staleness
subnets_whitelist
module_key
request_staleness
blockchain_cache
host_key
use_testnet
async def verify(self, request: starlette.requests.Request):
193    async def verify(self, request: Request):
194        body = await request.body()
195
196        # TODO: we'll replace this by a Result ADT :)
197        match self._check_inputs(request, body, self.module_key):
198            case (False, error):
199                return error
200            case (True, _):
201                pass
202
203        body_dict: dict[str, dict[str, Any]] = json.loads(body)
204        timestamp = body_dict["params"].get("timestamp", None)
205        legacy_timestamp = request.headers.get("X-Timestamp", None)
206        try:
207            timestamp_to_use = (
208                timestamp if not legacy_timestamp else legacy_timestamp
209            )
210            request_time = datetime.fromisoformat(timestamp_to_use)
211        except Exception:
212            return JSONResponse(
213                status_code=400,
214                content={"error": "Invalid ISO timestamp given"},
215            )
216        if (
217            datetime.now(timezone.utc) - request_time
218        ).total_seconds() > self.request_staleness:
219            return JSONResponse(
220                status_code=400, content={"error": "Request is too stale"}
221            )
222        return None

Please don't mutate the request.

def build_route_class(verifiers: Sequence[AbstractVerifier]) -> type[APIRoute]:
    """Create an APIRoute subclass that runs ``verifiers`` before ``/method`` endpoints.

    For request paths starting with ``/method``, each verifier is awaited in
    order and the first non-None response it returns is sent back instead of
    calling the endpoint. Every other path goes straight to the original
    route handler.

    :param verifiers: verifiers applied, in order, to ``/method`` requests.
    :returns: the generated APIRoute subclass.
    """

    class CheckListsRoute(APIRoute):
        def get_route_handler(self):
            handler = super().get_route_handler()

            async def custom_route_handler(
                request: Request,
            ) -> Response | JSONResponse:
                if request.url.path.startswith("/method"):
                    for verifier in verifiers:
                        rejection = await verifier.verify(request)
                        if rejection is not None:
                            # short-circuit: the endpoint is never called
                            return rejection
                response: Response = await handler(request)
                return response

            return custom_route_handler

    return CheckListsRoute