Edit on GitHub

communex.module.routers.module_routers

  1import json
  2import re
  3from abc import abstractmethod
  4from datetime import datetime, timezone
  5from functools import partial
  6from typing import Any, Protocol, Sequence
  7
  8import starlette.datastructures
  9from communex._common import get_node_url
 10from communex.module import _signer as signer
 11from communex.module._rate_limiters._stake_limiter import StakeLimiter
 12from communex.module._rate_limiters.limiters import (
 13    IpLimiterParams,
 14    StakeLimiterParams,
 15)
 16from communex.module._util import (
 17    json_error,
 18    log,
 19    log_reffusal,
 20    make_client,
 21    try_ss58_decode,
 22)
 23from communex.types import Ss58Address
 24from communex.util.memo import TTLDict
 25from fastapi import Request, Response
 26from fastapi.responses import JSONResponse
 27from fastapi.routing import APIRoute
 28from keylimiter import TokenBucketLimiter
 29from substrateinterface import Keypair
 30
 31HEX_PATTERN = re.compile(r"^[0-9a-fA-F]+$")
 32
 33
 34def is_hex_string(string: str):
 35    return bool(HEX_PATTERN.match(string))
 36
 37
 38def parse_hex(hex_str: str) -> bytes:
 39    if hex_str[0:2] == "0x":
 40        return bytes.fromhex(hex_str[2:])
 41    else:
 42        return bytes.fromhex(hex_str)
 43
 44
 45class AbstractVerifier(Protocol):
 46    @abstractmethod
 47    async def verify(self, request: Request) -> JSONResponse | None:
 48        """Please dont mutate the request D:"""
 49        ...
 50
 51
 52class StakeLimiterVerifier(AbstractVerifier):
 53    def __init__(
 54        self,
 55        subnets_whitelist: list[int] | None,
 56        params_: StakeLimiterParams | None,
 57    ):
 58        self.subnets_whitelist = subnets_whitelist
 59        self.params_ = params_
 60        if not self.params_:
 61            params = StakeLimiterParams()
 62        else:
 63            params = self.params_
 64        self.limiter = StakeLimiter(
 65            self.subnets_whitelist,
 66            epoch=params.epoch,
 67            max_cache_age=params.cache_age,
 68            get_refill_rate=params.get_refill_per_epoch,
 69        )
 70
 71    async def verify(self, request: Request):
 72        if request.client is None:
 73            response = JSONResponse(
 74                status_code=401,
 75                content={"error": "Address should be present in request"},
 76            )
 77            return response
 78
 79        key = request.headers.get("x-key")
 80        if not key:
 81            response = JSONResponse(
 82                status_code=401,
 83                content={"error": "Valid X-Key not provided on headers"},
 84            )
 85            return response
 86
 87        is_allowed = await self.limiter.allow(key)
 88
 89        if not is_allowed:
 90            response = JSONResponse(
 91                status_code=429,
 92                headers={
 93                    "X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds"
 94                },
 95                content={"error": "Rate limit exceeded"},
 96            )
 97            return response
 98        return None
 99
100
class ListVerifier(AbstractVerifier):
    """Refuse callers via a key blacklist, an IP blacklist or a key whitelist.

    A list that is None (or empty) disables the corresponding check.
    """

    def __init__(
        self,
        blacklist: list[Ss58Address] | None,
        whitelist: list[Ss58Address] | None,
        ip_blacklist: list[str] | None,
    ):
        self.blacklist, self.whitelist, self.ip_blacklist = (
            blacklist,
            whitelist,
            ip_blacklist,
        )

    async def verify(self, request: Request) -> JSONResponse | None:
        """Return a 400/403 error response if the caller is refused, else None."""
        raw_key = request.headers.get("x-key")
        if not raw_key:
            missing = "Missing header: X-Key"
            log(f"INFO: refusing module request because: {missing}")
            return json_error(400, missing)

        caller = try_ss58_decode(raw_key)
        if caller is None:
            why = "Caller key could not be decoded into a ss58address"
            log_reffusal(raw_key, why)
            return json_error(400, why)

        client = request.client
        if client is None:
            return json_error(400, "Address should be present in request")

        # Checks run in this order; the first matching refusal is returned.
        if self.blacklist and caller in self.blacklist:
            return json_error(403, "You are blacklisted")
        if self.ip_blacklist and client.host in self.ip_blacklist:
            return json_error(403, "Your IP is blacklisted")
        if self.whitelist and caller not in self.whitelist:
            return json_error(403, "You are not whitelisted")
        return None
133
134
class IpLimiterVerifier(AbstractVerifier):
    """Rate-limit requests per client IP address with a token bucket."""

    def __init__(
        self,
        params: IpLimiterParams | None,
    ):
        """
        :param params: IpLimiterParams instance OR None

        If params is None, a default ``IpLimiterParams()`` is used to build
        the TokenBucketLimiter (previously documented as bucket_size=200,
        refill_rate=15 -- confirm against IpLimiterParams).
        """

        # fallback to default limiter parameters
        if not params:
            params = IpLimiterParams()
        self._limiter = TokenBucketLimiter(
            bucket_size=params.bucket_size, refill_rate=params.refill_rate
        )

    async def verify(self, request: Request):
        """Return 400 when the client address is missing.

        Return 429 (with an ``X-RateLimit-Remaining`` header) when the client
        IP has exhausted its bucket. Otherwise return None.
        """
        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, and a failing assert would surface as a 500 rather
        # than a client error.
        if request.client is None or not request.client.host:
            return json_error(400, "Address should be present in request")

        ip = request.client.host

        if self._limiter.allow(ip):
            return None
        return JSONResponse(
            status_code=429,
            headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
            content={"error": "Rate limit exceeded"},
        )
172
173
174class InputHandlerVerifier(AbstractVerifier):
175    def __init__(
176        self,
177        subnets_whitelist: list[int] | None,
178        module_key: Ss58Address,
179        request_staleness: int,
180        blockchain_cache: TTLDict[str, list[Ss58Address]],
181        host_key: Keypair,
182        use_testnet: bool,
183    ):
184        self.subnets_whitelist = subnets_whitelist
185        self.module_key = module_key
186        self.request_staleness = request_staleness
187        self.blockchain_cache = blockchain_cache
188        self.host_key = host_key
189        self.use_testnet = use_testnet
190
191    async def verify(self, request: Request):
192        body = await request.body()
193
194        # TODO: we'll replace this by a Result ADT :)
195        match self._check_inputs(request, body, self.module_key):
196            case (False, error):
197                return error
198            case (True, _):
199                pass
200
201        body_dict: dict[str, dict[str, Any]] = json.loads(body)
202        timestamp = body_dict["params"].get("timestamp", None)
203        legacy_timestamp = request.headers.get("X-Timestamp", None)
204        try:
205            timestamp_to_use = (
206                timestamp if not legacy_timestamp else legacy_timestamp
207            )
208            request_time = datetime.fromisoformat(timestamp_to_use)
209        except Exception:
210            return JSONResponse(
211                status_code=400,
212                content={"error": "Invalid ISO timestamp given"},
213            )
214        if (
215            datetime.now(timezone.utc) - request_time
216        ).total_seconds() > self.request_staleness:
217            return JSONResponse(
218                status_code=400, content={"error": "Request is too stale"}
219            )
220        return None
221
222    def _check_inputs(
223        self, request: Request, body: bytes, module_key: Ss58Address
224    ):
225        required_headers = ["x-signature", "x-key", "x-crypto"]
226        optional_headers = ["x-timestamp"]
227
228        # TODO: we'll replace this by a Result ADT :)
229        match self._get_headers_dict(
230            request.headers, required_headers, optional_headers
231        ):
232            case (False, error):
233                return (False, error)
234            case (True, headers_dict):
235                pass
236
237        # TODO: we'll replace this by a Result ADT :)
238        match self._check_signature(headers_dict, body, module_key):
239            case (False, error):
240                return (False, error)
241            case (True, _):
242                pass
243
244        # TODO: we'll replace this by a Result ADT :)
245        match self._check_key_registered(
246            self.subnets_whitelist,
247            headers_dict,
248            self.blockchain_cache,
249            self.host_key,
250            self.use_testnet,
251        ):
252            case (False, error):
253                return (False, error)
254            case (True, _):
255                pass
256
257        return (True, None)
258
259    def _get_headers_dict(
260        self,
261        headers: starlette.datastructures.Headers,
262        required: list[str],
263        optional: list[str],
264    ):
265        headers_dict: dict[str, str] = {}
266        for required_header in required:
267            value = headers.get(required_header)
268            if not value:
269                code = 400
270                return False, json_error(
271                    code, f"Missing header: {required_header}"
272                )
273            headers_dict[required_header] = value
274        for optional_header in optional:
275            value = headers.get(optional_header)
276            if value:
277                headers_dict[optional_header] = value
278
279        return True, headers_dict
280
281    def _check_signature(
282        self, headers_dict: dict[str, str], body: bytes, module_key: Ss58Address
283    ):
284        key = headers_dict["x-key"]
285        signature = headers_dict["x-signature"]
286        crypto = int(headers_dict["x-crypto"])  # TODO: better handling of this
287
288        if not is_hex_string(key):
289            reason = "X-Key should be a hex value"
290            log_reffusal(key, reason)
291            return (False, json_error(400, reason))
292        try:
293            signature = parse_hex(signature)
294        except Exception:
295            reason = "Signature sent is not a valid hex value"
296            log_reffusal(key, reason)
297            return False, json_error(400, reason)
298        try:
299            key = parse_hex(key)
300        except Exception:
301            reason = "Key sent is not a valid hex value"
302            log_reffusal(key, reason)
303            return False, json_error(400, reason)
304        # decodes the key for better logging
305        key_ss58 = try_ss58_decode(key)
306        if key_ss58 is None:
307            reason = "Caller key could not be decoded into a ss58address"
308            log_reffusal(key.decode(), reason)
309            return (False, json_error(400, reason))
310
311        timestamp = headers_dict.get("x-timestamp")
312        legacy_verified = False
313        if timestamp:
314            # tries to do a legacy verification
315            json_body = json.loads(body)
316            json_body["timestamp"] = timestamp
317            stamped_body = json.dumps(json_body).encode()
318            legacy_verified = signer.verify(
319                key, crypto, stamped_body, signature
320            )
321
322        verified = signer.verify(key, crypto, body, signature)
323        if not verified and not legacy_verified:
324            reason = "Signature doesn't match"
325            log_reffusal(key_ss58, reason)
326            return (False, json_error(401, "Signatures doesn't match"))
327
328        body_dict: dict[str, dict[str, Any]] = json.loads(body)
329        target_key = body_dict["params"].get("target_key", None)
330        if not target_key or target_key != module_key:
331            reason = "Wrong target_key in body"
332            log_reffusal(key_ss58, reason)
333            return (False, json_error(401, "Wrong target_key in body"))
334
335        return (True, None)
336
337    def _check_key_registered(
338        self,
339        subnets_whitelist: list[int] | None,
340        headers_dict: dict[str, str],
341        blockchain_cache: TTLDict[str, list[Ss58Address]],
342        host_key: Keypair,
343        use_testnet: bool,
344    ):
345        key = headers_dict["x-key"]
346        if not is_hex_string(key):
347            return (False, json_error(400, "X-Key should be a hex value"))
348        key = parse_hex(key)
349
350        # TODO: checking for key being registered should be smarter
351        # e.g. query and store all registered modules periodically.
352
353        ss58 = try_ss58_decode(key)
354        if ss58 is None:
355            reason = "Caller key could not be decoded into a ss58address"
356            log_reffusal(key.decode(), reason)
357            return (False, json_error(400, reason))
358
359        # If subnets whitelist is specified, checks if key is registered in one
360        # of the given subnets
361
362        allowed_subnets: dict[int, bool] = {}
363        caller_subnets: list[int] = []
364        if subnets_whitelist is not None:
365
366            def query_keys(subnet: int):
367                try:
368                    node_url = get_node_url(None, use_testnet=use_testnet)
369                    client = make_client(
370                        node_url
371                    )  # TODO: get client from outer context
372                    return [*client.query_map_key(subnet).values()]
373                except Exception:
374                    log("WARNING: Could not connect to a blockchain node")
375                    return_list: list[Ss58Address] = []
376                    return return_list
377
378            # TODO: client pool for entire module server
379
380            got_keys = False
381            no_keys_reason = (
382                "Miner could not connect to a blockchain node "
383                "or there is no key registered on the subnet(s) {} "
384            )
385            for subnet in subnets_whitelist:
386                get_keys_on_subnet = partial(query_keys, subnet)
387                cache_key = f"keys_on_subnet_{subnet}"
388                keys_on_subnet = blockchain_cache.get_or_insert_lazy(
389                    cache_key, get_keys_on_subnet
390                )
391                if len(keys_on_subnet) == 0:
392                    reason = no_keys_reason.format(subnet)
393                    log(f"WARNING: {reason}")
394                else:
395                    got_keys = True
396                if host_key.ss58_address not in keys_on_subnet:
397                    log(
398                        f"WARNING: This miner is deregistered on subnet {subnet}"
399                    )
400                else:
401                    allowed_subnets[subnet] = True
402                if ss58 in keys_on_subnet:
403                    caller_subnets.append(subnet)
404            if not got_keys:
405                return False, json_error(
406                    503, no_keys_reason.format(subnets_whitelist)
407                )
408            if not allowed_subnets:
409                log("WARNING: Miner is not registered on any subnet")
410                return False, json_error(
411                    403, "Miner is not registered on any subnet"
412                )
413
414            # searches for a common subnet between caller and miner
415            # TODO: use sets
416            allowed_subnets = {
417                subnet: allowed
418                for subnet, allowed in allowed_subnets.items()
419                if (subnet in caller_subnets)
420            }
421            if not allowed_subnets:
422                reason = "Caller key is not registered in any subnet that the miner is"
423                log_reffusal(ss58, reason)
424                return False, json_error(403, reason)
425        else:
426            # accepts everything
427            pass
428
429        return (True, None)
430
431
def build_route_class(verifiers: Sequence[AbstractVerifier]) -> type[APIRoute]:
    """Build an ``APIRoute`` subclass that runs ``verifiers`` on ``/method`` paths.

    For request paths starting with ``/method``, each verifier is awaited in
    order. The first non-None response is returned and the route handler is
    not called. Requests on any other path go straight to the handler.
    """

    class CheckListsRoute(APIRoute):
        def get_route_handler(self):
            handle = super().get_route_handler()

            async def custom_route_handler(
                request: Request,
            ) -> Response | JSONResponse:
                if request.url.path.startswith("/method"):
                    for verifier in verifiers:
                        rejection = await verifier.verify(request)
                        if rejection is not None:
                            return rejection
                response: Response = await handle(request)
                return response

            return custom_route_handler

    return CheckListsRoute
HEX_PATTERN = re.compile(r"^[0-9a-fA-F]+$")


def is_hex_string(string: str):
    r"""Return True if ``string`` consists solely of one or more hex digits.

    No ``0x`` prefix, whitespace or trailing newline is accepted.
    ``fullmatch`` is used because ``$`` alone also matches just before a
    trailing ``\n``, so ``HEX_PATTERN.match("ab\n")`` would succeed.
    """
    return bool(HEX_PATTERN.fullmatch(string))
def parse_hex(hex_str: str) -> bytes:
    """Decode a hex string, with or without a leading ``0x``, into bytes.

    Raises ValueError (from ``bytes.fromhex``) when the digits are not
    valid hex.
    """
    return bytes.fromhex(hex_str.removeprefix("0x"))
class AbstractVerifier(Protocol):
    """A check run against a request before the route handler executes."""

    @abstractmethod
    async def verify(self, request: Request) -> JSONResponse | None:
        """Return an error response to reject ``request``, or None to accept it.

        Implementations must not mutate the request.
        """

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
@abstractmethod
async def verify(self, request: Request) -> JSONResponse | None:
    """Return an error response to reject ``request``, or None to accept it.

    Implementations must not mutate the request.
    """

Implementations must not mutate the request.

class StakeLimiterVerifier(AbstractVerifier):
    """Rate-limit callers by their ``X-Key`` header using a ``StakeLimiter``.

    The limiter's epoch, cache age and per-epoch refill rate come from
    ``StakeLimiterParams``; the defaults are used when none are given.
    """

    def __init__(
        self,
        subnets_whitelist: list[int] | None,
        params_: StakeLimiterParams | None,
    ):
        self.subnets_whitelist = subnets_whitelist
        self.params_ = params_
        # A falsy ``params_`` (e.g. None) falls back to the default parameters.
        effective = params_ if params_ else StakeLimiterParams()
        self.limiter = StakeLimiter(
            subnets_whitelist,
            epoch=effective.epoch,
            max_cache_age=effective.cache_age,
            get_refill_rate=effective.get_refill_per_epoch,
        )

    async def verify(self, request: Request):
        """Return 401 when the client address or ``X-Key`` is missing.

        Return 429 (with an ``X-RateLimit-TryAfter`` header) when the key is
        rate limited. Otherwise return None.
        """
        if request.client is None:
            return JSONResponse(
                status_code=401,
                content={"error": "Address should be present in request"},
            )

        caller_key = request.headers.get("x-key")
        if not caller_key:
            return JSONResponse(
                status_code=401,
                content={"error": "Valid X-Key not provided on headers"},
            )

        if await self.limiter.allow(caller_key):
            return None

        wait = await self.limiter.retry_after(caller_key)
        return JSONResponse(
            status_code=429,
            headers={"X-RateLimit-TryAfter": f"{wait!s} seconds"},
            content={"error": "Rate limit exceeded"},
        )

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
def __init__(
    self,
    subnets_whitelist: list[int] | None,
    params_: StakeLimiterParams | None,
):
    """Store the configuration and build the ``StakeLimiter``.

    A falsy ``params_`` (e.g. None) falls back to ``StakeLimiterParams()``.
    """
    self.subnets_whitelist = subnets_whitelist
    self.params_ = params_
    effective = params_ if params_ else StakeLimiterParams()
    self.limiter = StakeLimiter(
        subnets_whitelist,
        epoch=effective.epoch,
        max_cache_age=effective.cache_age,
        get_refill_rate=effective.get_refill_per_epoch,
    )
subnets_whitelist
params_
limiter
async def verify(self, request: Request):
    """Return 401 when the client address or ``X-Key`` is missing.

    Return 429 (with an ``X-RateLimit-TryAfter`` header) when the key is
    rate limited. Otherwise return None.
    """
    if request.client is None:
        return JSONResponse(
            status_code=401,
            content={"error": "Address should be present in request"},
        )

    caller_key = request.headers.get("x-key")
    if not caller_key:
        return JSONResponse(
            status_code=401,
            content={"error": "Valid X-Key not provided on headers"},
        )

    if await self.limiter.allow(caller_key):
        return None

    wait = await self.limiter.retry_after(caller_key)
    return JSONResponse(
        status_code=429,
        headers={"X-RateLimit-TryAfter": f"{wait!s} seconds"},
        content={"error": "Rate limit exceeded"},
    )

Implementations must not mutate the request.

class ListVerifier(AbstractVerifier):
    """Refuse callers via a key blacklist, an IP blacklist or a key whitelist.

    A list that is None (or empty) disables the corresponding check.
    """

    def __init__(
        self,
        blacklist: list[Ss58Address] | None,
        whitelist: list[Ss58Address] | None,
        ip_blacklist: list[str] | None,
    ):
        self.blacklist, self.whitelist, self.ip_blacklist = (
            blacklist,
            whitelist,
            ip_blacklist,
        )

    async def verify(self, request: Request) -> JSONResponse | None:
        """Return a 400/403 error response if the caller is refused, else None."""
        raw_key = request.headers.get("x-key")
        if not raw_key:
            missing = "Missing header: X-Key"
            log(f"INFO: refusing module request because: {missing}")
            return json_error(400, missing)

        caller = try_ss58_decode(raw_key)
        if caller is None:
            why = "Caller key could not be decoded into a ss58address"
            log_reffusal(raw_key, why)
            return json_error(400, why)

        client = request.client
        if client is None:
            return json_error(400, "Address should be present in request")

        # Checks run in this order; the first matching refusal is returned.
        if self.blacklist and caller in self.blacklist:
            return json_error(403, "You are blacklisted")
        if self.ip_blacklist and client.host in self.ip_blacklist:
            return json_error(403, "Your IP is blacklisted")
        if self.whitelist and caller not in self.whitelist:
            return json_error(403, "You are not whitelisted")
        return None

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
def __init__(
    self,
    blacklist: list[Ss58Address] | None,
    whitelist: list[Ss58Address] | None,
    ip_blacklist: list[str] | None,
):
    """Store the key blacklist/whitelist and the IP blacklist.

    A list that is None (or empty) disables the corresponding check.
    """
    self.blacklist, self.whitelist, self.ip_blacklist = (
        blacklist,
        whitelist,
        ip_blacklist,
    )
blacklist
whitelist
ip_blacklist
async def verify(self, request: Request) -> JSONResponse | None:
    """Return a 400/403 error response if the caller is refused, else None."""
    raw_key = request.headers.get("x-key")
    if not raw_key:
        missing = "Missing header: X-Key"
        log(f"INFO: refusing module request because: {missing}")
        return json_error(400, missing)

    caller = try_ss58_decode(raw_key)
    if caller is None:
        why = "Caller key could not be decoded into a ss58address"
        log_reffusal(raw_key, why)
        return json_error(400, why)

    client = request.client
    if client is None:
        return json_error(400, "Address should be present in request")

    # Checks run in this order; the first matching refusal is returned.
    if self.blacklist and caller in self.blacklist:
        return json_error(403, "You are blacklisted")
    if self.ip_blacklist and client.host in self.ip_blacklist:
        return json_error(403, "Your IP is blacklisted")
    if self.whitelist and caller not in self.whitelist:
        return json_error(403, "You are not whitelisted")
    return None

Implementations must not mutate the request.

class IpLimiterVerifier(AbstractVerifier):
    """Rate-limit requests per client IP address with a token bucket."""

    def __init__(
        self,
        params: IpLimiterParams | None,
    ):
        """
        :param params: IpLimiterParams instance OR None

        If params is None, a default ``IpLimiterParams()`` is used to build
        the TokenBucketLimiter (previously documented as bucket_size=200,
        refill_rate=15 -- confirm against IpLimiterParams).
        """

        # fallback to default limiter parameters
        if not params:
            params = IpLimiterParams()
        self._limiter = TokenBucketLimiter(
            bucket_size=params.bucket_size, refill_rate=params.refill_rate
        )

    async def verify(self, request: Request):
        """Return 400 when the client address is missing.

        Return 429 (with an ``X-RateLimit-Remaining`` header) when the client
        IP has exhausted its bucket. Otherwise return None.
        """
        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, and a failing assert would surface as a 500 rather
        # than a client error.
        if request.client is None or not request.client.host:
            return json_error(400, "Address should be present in request")

        ip = request.client.host

        if self._limiter.allow(ip):
            return None
        return JSONResponse(
            status_code=429,
            headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
            content={"error": "Rate limit exceeded"},
        )

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
def __init__(
    self,
    params: IpLimiterParams | None,
):
    """
    :param params: IpLimiterParams instance OR None

    If params is None, a default ``IpLimiterParams()`` is used to build the
    TokenBucketLimiter (previously documented as bucket_size=200,
    refill_rate=15 -- confirm against IpLimiterParams).
    """

    # fallback to default limiter parameters
    if not params:
        params = IpLimiterParams()
    self._limiter = TokenBucketLimiter(
        bucket_size=params.bucket_size, refill_rate=params.refill_rate
    )
Parameters
  • params: IpLimiterParams instance OR None

If params is None, a default IpLimiterParams() is used to build the TokenBucketLimiter (documented as bucket_size=200, refill_rate=15).

async def verify(self, request: Request):
    """Return 400 when the client address is missing.

    Return 429 (with an ``X-RateLimit-Remaining`` header) when the client IP
    has exhausted its bucket. Otherwise return None.
    """
    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, and a failing assert would surface as a 500 rather than a
    # client error.
    if request.client is None or not request.client.host:
        return json_error(400, "Address should be present in request")

    ip = request.client.host

    if self._limiter.allow(ip):
        return None
    return JSONResponse(
        status_code=429,
        headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
        content={"error": "Rate limit exceeded"},
    )

Implementations must not mutate the request.

class InputHandlerVerifier(AbstractVerifier):
175class InputHandlerVerifier(AbstractVerifier):
176    def __init__(
177        self,
178        subnets_whitelist: list[int] | None,
179        module_key: Ss58Address,
180        request_staleness: int,
181        blockchain_cache: TTLDict[str, list[Ss58Address]],
182        host_key: Keypair,
183        use_testnet: bool,
184    ):
185        self.subnets_whitelist = subnets_whitelist
186        self.module_key = module_key
187        self.request_staleness = request_staleness
188        self.blockchain_cache = blockchain_cache
189        self.host_key = host_key
190        self.use_testnet = use_testnet
191
192    async def verify(self, request: Request):
193        body = await request.body()
194
195        # TODO: we'll replace this by a Result ADT :)
196        match self._check_inputs(request, body, self.module_key):
197            case (False, error):
198                return error
199            case (True, _):
200                pass
201
202        body_dict: dict[str, dict[str, Any]] = json.loads(body)
203        timestamp = body_dict["params"].get("timestamp", None)
204        legacy_timestamp = request.headers.get("X-Timestamp", None)
205        try:
206            timestamp_to_use = (
207                timestamp if not legacy_timestamp else legacy_timestamp
208            )
209            request_time = datetime.fromisoformat(timestamp_to_use)
210        except Exception:
211            return JSONResponse(
212                status_code=400,
213                content={"error": "Invalid ISO timestamp given"},
214            )
215        if (
216            datetime.now(timezone.utc) - request_time
217        ).total_seconds() > self.request_staleness:
218            return JSONResponse(
219                status_code=400, content={"error": "Request is too stale"}
220            )
221        return None
222
223    def _check_inputs(
224        self, request: Request, body: bytes, module_key: Ss58Address
225    ):
226        required_headers = ["x-signature", "x-key", "x-crypto"]
227        optional_headers = ["x-timestamp"]
228
229        # TODO: we'll replace this by a Result ADT :)
230        match self._get_headers_dict(
231            request.headers, required_headers, optional_headers
232        ):
233            case (False, error):
234                return (False, error)
235            case (True, headers_dict):
236                pass
237
238        # TODO: we'll replace this by a Result ADT :)
239        match self._check_signature(headers_dict, body, module_key):
240            case (False, error):
241                return (False, error)
242            case (True, _):
243                pass
244
245        # TODO: we'll replace this by a Result ADT :)
246        match self._check_key_registered(
247            self.subnets_whitelist,
248            headers_dict,
249            self.blockchain_cache,
250            self.host_key,
251            self.use_testnet,
252        ):
253            case (False, error):
254                return (False, error)
255            case (True, _):
256                pass
257
258        return (True, None)
259
260    def _get_headers_dict(
261        self,
262        headers: starlette.datastructures.Headers,
263        required: list[str],
264        optional: list[str],
265    ):
266        headers_dict: dict[str, str] = {}
267        for required_header in required:
268            value = headers.get(required_header)
269            if not value:
270                code = 400
271                return False, json_error(
272                    code, f"Missing header: {required_header}"
273                )
274            headers_dict[required_header] = value
275        for optional_header in optional:
276            value = headers.get(optional_header)
277            if value:
278                headers_dict[optional_header] = value
279
280        return True, headers_dict
281
282    def _check_signature(
283        self, headers_dict: dict[str, str], body: bytes, module_key: Ss58Address
284    ):
285        key = headers_dict["x-key"]
286        signature = headers_dict["x-signature"]
287        crypto = int(headers_dict["x-crypto"])  # TODO: better handling of this
288
289        if not is_hex_string(key):
290            reason = "X-Key should be a hex value"
291            log_reffusal(key, reason)
292            return (False, json_error(400, reason))
293        try:
294            signature = parse_hex(signature)
295        except Exception:
296            reason = "Signature sent is not a valid hex value"
297            log_reffusal(key, reason)
298            return False, json_error(400, reason)
299        try:
300            key = parse_hex(key)
301        except Exception:
302            reason = "Key sent is not a valid hex value"
303            log_reffusal(key, reason)
304            return False, json_error(400, reason)
305        # decodes the key for better logging
306        key_ss58 = try_ss58_decode(key)
307        if key_ss58 is None:
308            reason = "Caller key could not be decoded into a ss58address"
309            log_reffusal(key.decode(), reason)
310            return (False, json_error(400, reason))
311
312        timestamp = headers_dict.get("x-timestamp")
313        legacy_verified = False
314        if timestamp:
315            # tries to do a legacy verification
316            json_body = json.loads(body)
317            json_body["timestamp"] = timestamp
318            stamped_body = json.dumps(json_body).encode()
319            legacy_verified = signer.verify(
320                key, crypto, stamped_body, signature
321            )
322
323        verified = signer.verify(key, crypto, body, signature)
324        if not verified and not legacy_verified:
325            reason = "Signature doesn't match"
326            log_reffusal(key_ss58, reason)
327            return (False, json_error(401, "Signatures doesn't match"))
328
329        body_dict: dict[str, dict[str, Any]] = json.loads(body)
330        target_key = body_dict["params"].get("target_key", None)
331        if not target_key or target_key != module_key:
332            reason = "Wrong target_key in body"
333            log_reffusal(key_ss58, reason)
334            return (False, json_error(401, "Wrong target_key in body"))
335
336        return (True, None)
337
338    def _check_key_registered(
339        self,
340        subnets_whitelist: list[int] | None,
341        headers_dict: dict[str, str],
342        blockchain_cache: TTLDict[str, list[Ss58Address]],
343        host_key: Keypair,
344        use_testnet: bool,
345    ):
346        key = headers_dict["x-key"]
347        if not is_hex_string(key):
348            return (False, json_error(400, "X-Key should be a hex value"))
349        key = parse_hex(key)
350
351        # TODO: checking for key being registered should be smarter
352        # e.g. query and store all registered modules periodically.
353
354        ss58 = try_ss58_decode(key)
355        if ss58 is None:
356            reason = "Caller key could not be decoded into a ss58address"
357            log_reffusal(key.decode(), reason)
358            return (False, json_error(400, reason))
359
360        # If subnets whitelist is specified, checks if key is registered in one
361        # of the given subnets
362
363        allowed_subnets: dict[int, bool] = {}
364        caller_subnets: list[int] = []
365        if subnets_whitelist is not None:
366
367            def query_keys(subnet: int):
368                try:
369                    node_url = get_node_url(None, use_testnet=use_testnet)
370                    client = make_client(
371                        node_url
372                    )  # TODO: get client from outer context
373                    return [*client.query_map_key(subnet).values()]
374                except Exception:
375                    log("WARNING: Could not connect to a blockchain node")
376                    return_list: list[Ss58Address] = []
377                    return return_list
378
379            # TODO: client pool for entire module server
380
381            got_keys = False
382            no_keys_reason = (
383                "Miner could not connect to a blockchain node "
384                "or there is no key registered on the subnet(s) {} "
385            )
386            for subnet in subnets_whitelist:
387                get_keys_on_subnet = partial(query_keys, subnet)
388                cache_key = f"keys_on_subnet_{subnet}"
389                keys_on_subnet = blockchain_cache.get_or_insert_lazy(
390                    cache_key, get_keys_on_subnet
391                )
392                if len(keys_on_subnet) == 0:
393                    reason = no_keys_reason.format(subnet)
394                    log(f"WARNING: {reason}")
395                else:
396                    got_keys = True
397                if host_key.ss58_address not in keys_on_subnet:
398                    log(
399                        f"WARNING: This miner is deregistered on subnet {subnet}"
400                    )
401                else:
402                    allowed_subnets[subnet] = True
403                if ss58 in keys_on_subnet:
404                    caller_subnets.append(subnet)
405            if not got_keys:
406                return False, json_error(
407                    503, no_keys_reason.format(subnets_whitelist)
408                )
409            if not allowed_subnets:
410                log("WARNING: Miner is not registered on any subnet")
411                return False, json_error(
412                    403, "Miner is not registered on any subnet"
413                )
414
415            # searches for a common subnet between caller and miner
416            # TODO: use sets
417            allowed_subnets = {
418                subnet: allowed
419                for subnet, allowed in allowed_subnets.items()
420                if (subnet in caller_subnets)
421            }
422            if not allowed_subnets:
423                reason = "Caller key is not registered in any subnet that the miner is"
424                log_reffusal(ss58, reason)
425                return False, json_error(403, reason)
426        else:
427            # accepts everything
428            pass
429
430        return (True, None)

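Verifies an incoming request. It requires the X-Signature, X-Key and X-Crypto headers and checks that the signature verifies against the request body. It also checks that params.target_key in the body equals this module's key. When subnets_whitelist is set, it checks that the caller is registered on a whitelisted subnet the miner is also registered on. Finally, it checks that the request timestamp is no older than request_staleness seconds. The text below is inherited from typing.Protocol.

Base class for protocol classes.

Protocol classes are defined as: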

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example:

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as:

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
InputHandlerVerifier( subnets_whitelist: list[int] | None, module_key: communex.types.Ss58Address, request_staleness: int, blockchain_cache: communex.util.memo.TTLDict[str, list[communex.types.Ss58Address]], host_key: substrateinterface.keypair.Keypair, use_testnet: bool)
176    def __init__(
177        self,
178        subnets_whitelist: list[int] | None,
179        module_key: Ss58Address,
180        request_staleness: int,
181        blockchain_cache: TTLDict[str, list[Ss58Address]],
182        host_key: Keypair,
183        use_testnet: bool,
184    ):
185        self.subnets_whitelist = subnets_whitelist
186        self.module_key = module_key
187        self.request_staleness = request_staleness
188        self.blockchain_cache = blockchain_cache
189        self.host_key = host_key
190        self.use_testnet = use_testnet
subnets_whitelist
module_key
request_staleness
blockchain_cache
host_key
use_testnet
async def verify(self, request: starlette.requests.Request):
192    async def verify(self, request: Request):
193        body = await request.body()
194
195        # TODO: we'll replace this by a Result ADT :)
196        match self._check_inputs(request, body, self.module_key):
197            case (False, error):
198                return error
199            case (True, _):
200                pass
201
202        body_dict: dict[str, dict[str, Any]] = json.loads(body)
203        timestamp = body_dict["params"].get("timestamp", None)
204        legacy_timestamp = request.headers.get("X-Timestamp", None)
205        try:
206            timestamp_to_use = (
207                timestamp if not legacy_timestamp else legacy_timestamp
208            )
209            request_time = datetime.fromisoformat(timestamp_to_use)
210        except Exception:
211            return JSONResponse(
212                status_code=400,
213                content={"error": "Invalid ISO timestamp given"},
214            )
215        if (
216            datetime.now(timezone.utc) - request_time
217        ).total_seconds() > self.request_staleness:
218            return JSONResponse(
219                status_code=400, content={"error": "Request is too stale"}
220            )
221        return None

Please don't mutate the request.

def build_route_class( verifiers: Sequence[AbstractVerifier]) -> type[fastapi.routing.APIRoute]:
433def build_route_class(verifiers: Sequence[AbstractVerifier]) -> type[APIRoute]:
434    class CheckListsRoute(APIRoute):
435        def get_route_handler(self):
436            original_route_handler = super().get_route_handler()
437
438            async def custom_route_handler(
439                request: Request,
440            ) -> Response | JSONResponse:
441                if not request.url.path.startswith("/method"):
442                    unhandled_response: Response = await original_route_handler(
443                        request
444                    )
445                    return unhandled_response
446                for verifier in verifiers:
447                    response = await verifier.verify(request)
448                    if response is not None:
449                        return response
450
451                original_response: Response = await original_route_handler(
452                    request
453                )
454                return original_response
455
456            return custom_route_handler
457
458    return CheckListsRoute