Edit on GitHub

communex.module.routers.module_routers

  1import json
  2import re
  3from abc import abstractmethod
  4from datetime import datetime, timezone
  5from functools import partial
  6from typing import Any, Protocol, Sequence
  7
  8import starlette.datastructures
  9from fastapi import Request, Response
 10from fastapi.responses import JSONResponse
 11from fastapi.routing import APIRoute
 12from keylimiter import TokenBucketLimiter
 13from substrateinterface import Keypair  # type: ignore
 14
 15from communex._common import get_node_url
 16from communex.module import _signer as signer
 17from communex.module._rate_limiters._stake_limiter import StakeLimiter
 18from communex.module._rate_limiters.limiters import (IpLimiterParams,
 19                                                     StakeLimiterParams)
 20from communex.module._util import (json_error, log, log_reffusal, make_client,
 21                                   try_ss58_decode)
 22from communex.types import Ss58Address
 23from communex.util.memo import TTLDict
 24
 25HEX_PATTERN = re.compile(r"^[0-9a-fA-F]+$")
 26
 27
 28def is_hex_string(string: str):
 29    return bool(HEX_PATTERN.match(string))
 30
 31
 32def parse_hex(hex_str: str) -> bytes:
 33    if hex_str[0:2] == "0x":
 34        return bytes.fromhex(hex_str[2:])
 35    else:
 36        return bytes.fromhex(hex_str)
 37
 38
class AbstractVerifier(Protocol):
    """Structural interface implemented by every request verifier."""

    @abstractmethod
    async def verify(self, request: Request) -> JSONResponse | None:
        """Check `request` and report whether it may proceed.

        Return ``None`` to let the request through, or a ``JSONResponse``
        describing why it was refused. Implementations should not mutate
        the request.
        """
        ...
 44
 45
 46class StakeLimiterVerifier(AbstractVerifier):
 47    def __init__(
 48        self,
 49        subnets_whitelist: list[int] | None,
 50        params_: StakeLimiterParams | None,
 51    ):
 52        self.subnets_whitelist = subnets_whitelist
 53        self.params_ = params_
 54        if not self.params_:
 55            params = StakeLimiterParams()
 56        else:
 57            params = self.params_
 58        self.limiter = StakeLimiter(
 59            self.subnets_whitelist,
 60            epoch=params.epoch,
 61            max_cache_age=params.cache_age,
 62            get_refill_rate=params.get_refill_per_epoch,
 63        )
 64
 65    async def verify(self, request: Request):
 66
 67        if request.client is None:
 68            response = JSONResponse(
 69                status_code=401,
 70                content={
 71                    "error": "Address should be present in request"
 72                }
 73            )
 74            return response
 75
 76        key = request.headers.get('x-key')
 77        if not key:
 78            response = JSONResponse(
 79                status_code=401,
 80                content={"error": "Valid X-Key not provided on headers"}
 81            )
 82            return response
 83
 84        is_allowed = await self.limiter.allow(key)
 85
 86        if not is_allowed:
 87            response = JSONResponse(
 88                status_code=429,
 89                headers={"X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds"},
 90                content={"error": "Rate limit exceeded"}
 91            )
 92            return response
 93        return None
 94
 95
 96class ListVerifier(AbstractVerifier):
 97    def __init__(
 98            self,
 99            blacklist: list[Ss58Address] | None,
100            whitelist: list[Ss58Address] | None,
101            ip_blacklist: list[str] | None,
102    ):
103        self.blacklist = blacklist
104        self.whitelist = whitelist
105        self.ip_blacklist = ip_blacklist
106
107    async def verify(self, request: Request) -> JSONResponse | None:
108        key = request.headers.get("x-key")
109        if not key:
110            reason = "Missing header: X-Key"
111            log(f"INFO: refusing module request because: {reason}")
112            return json_error(400, "Missing header: X-Key")
113
114        ss58 = try_ss58_decode(key)
115        if ss58 is None:
116            reason = "Caller key could not be decoded into a ss58address"
117            log_reffusal(key, reason)
118            return json_error(400, reason)
119        if request.client is None:
120            return json_error(400, "Address should be present in request")
121        if self.blacklist and ss58 in self.blacklist:
122            return json_error(403, "You are blacklisted")
123        if self.ip_blacklist and request.client.host in self.ip_blacklist:
124            return json_error(403, "Your IP is blacklisted")
125        if self.whitelist and ss58 not in self.whitelist:
126            return json_error(403, "You are not whitelisted")
127        return None
128
129
class IpLimiterVerifier(AbstractVerifier):
    """Verifier that rate-limits requests per client IP with a token bucket."""

    def __init__(
            self,
            params: IpLimiterParams | None,
    ):
        '''
        :param params: IpLimiterParams instance OR None

        If params is None (or otherwise falsy), the limiter is built from
        the defaults of ``IpLimiterParams()``.
        NOTE(review): earlier documentation quoted those defaults as
        bucket_size=200, refill_rate=15 -- confirm against IpLimiterParams.
        '''

        # fallback to default limiter settings
        if not params:
            params = IpLimiterParams()
        self._limiter = TokenBucketLimiter(
            bucket_size=params.bucket_size,
            refill_rate=params.refill_rate
        )

    async def verify(self, request: Request) -> JSONResponse | None:
        """Refuse the request when its client IP has exhausted its bucket.

        Returns a 400 response when the request carries no client address,
        a 429 response when the limiter denies the IP, and ``None``
        otherwise.
        """
        client = request.client
        # Validate explicitly instead of using `assert`: assertions are
        # stripped under `python -O` and would otherwise surface as a 500.
        if client is None or not client.host:
            return JSONResponse(
                status_code=400,
                content={"error": "Address should be present in request"},
            )

        ip = client.host
        if self._limiter.allow(ip):
            return None

        return JSONResponse(
            status_code=429,
            headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
            content={"error": "Rate limit exceeded"},
        )
166
167
168class InputHandlerVerifier(AbstractVerifier):
169    def __init__(
170        self,
171        subnets_whitelist: list[int] | None,
172        module_key: Ss58Address,
173        request_staleness: int,
174        blockchain_cache: TTLDict[str, list[Ss58Address]],
175        host_key: Keypair,
176        use_testnet: bool,
177    ):
178        self.subnets_whitelist = subnets_whitelist
179        self.module_key = module_key
180        self.request_staleness = request_staleness
181        self.blockchain_cache = blockchain_cache
182        self.host_key = host_key
183        self.use_testnet = use_testnet
184
185    async def verify(self, request: Request):
186        body = await request.body()
187
188        # TODO: we'll replace this by a Result ADT :)
189        match self._check_inputs(request, body, self.module_key):
190            case (False, error):
191                return error
192            case (True, _):
193                pass
194
195        body_dict: dict[str, dict[str, Any]] = json.loads(body)
196        timestamp = body_dict['params'].get("timestamp", None)
197        legacy_timestamp = request.headers.get("X-Timestamp", None)
198        try:
199            timestamp_to_use = timestamp if not legacy_timestamp else legacy_timestamp
200            request_time = datetime.fromisoformat(timestamp_to_use)
201        except Exception:
202            return JSONResponse(status_code=400, content={"error": "Invalid ISO timestamp given"})
203        if (datetime.now(timezone.utc) - request_time).total_seconds() > self.request_staleness:
204            return JSONResponse(status_code=400, content={"error": "Request is too stale"})
205        return None
206
207    def _check_inputs(
208        self, request: Request,
209        body: bytes,
210        module_key: Ss58Address
211    ):
212        required_headers = ["x-signature", "x-key", "x-crypto"]
213        optional_headers = ["x-timestamp"]
214
215        # TODO: we'll replace this by a Result ADT :)
216        match self._get_headers_dict(request.headers, required_headers, optional_headers):
217            case (False, error):
218                return (False, error)
219            case (True, headers_dict):
220                pass
221
222        # TODO: we'll replace this by a Result ADT :)
223        match self._check_signature(headers_dict, body, module_key):
224            case (False, error):
225                return (False, error)
226            case (True, _):
227                pass
228
229        # TODO: we'll replace this by a Result ADT :)
230        match self._check_key_registered(
231            self.subnets_whitelist,
232            headers_dict,
233            self.blockchain_cache,
234            self.host_key,
235            self.use_testnet,
236        ):
237            case (False, error):
238                return (False, error)
239            case (True, _):
240                pass
241
242        return (True, None)
243
244    def _get_headers_dict(
245        self,
246        headers: starlette.datastructures.Headers,
247        required: list[str],
248        optional: list[str],
249    ):
250        headers_dict: dict[str, str] = {}
251        for required_header in required:
252            value = headers.get(required_header)
253            if not value:
254                code = 400
255                return False, json_error(code, f"Missing header: {required_header}")
256            headers_dict[required_header] = value
257        for optional_header in optional:
258            value = headers.get(optional_header)
259            if value:
260                headers_dict[optional_header] = value
261
262        return True, headers_dict
263
264    def _check_signature(
265        self,
266        headers_dict: dict[str, str],
267        body: bytes,
268        module_key: Ss58Address
269    ):
270        key = headers_dict["x-key"]
271        signature = headers_dict["x-signature"]
272        crypto = int(headers_dict["x-crypto"])  # TODO: better handling of this
273
274        if not is_hex_string(key):
275            reason = "X-Key should be a hex value"
276            log_reffusal(key, reason)
277            return (False, json_error(400, reason))
278        try:
279            signature = parse_hex(signature)
280        except Exception:
281            reason = "Signature sent is not a valid hex value"
282            log_reffusal(key, reason)
283            return False, json_error(400, reason)
284        try:
285            key = parse_hex(key)
286        except Exception:
287            reason = "Key sent is not a valid hex value"
288            log_reffusal(key, reason)
289            return False, json_error(400, reason)
290        # decodes the key for better logging
291        key_ss58 = try_ss58_decode(key)
292        if key_ss58 is None:
293            reason = "Caller key could not be decoded into a ss58address"
294            log_reffusal(key.decode(), reason)
295            return (False, json_error(400, reason))
296
297        timestamp = headers_dict.get("x-timestamp")
298        legacy_verified = False
299        if timestamp:
300            # tries to do a legacy verification
301            json_body = json.loads(body)
302            json_body["timestamp"] = timestamp
303            stamped_body = json.dumps(json_body).encode()
304            legacy_verified = signer.verify(key, crypto, stamped_body, signature)
305
306        verified = signer.verify(key, crypto, body, signature)
307        if not verified and not legacy_verified:
308            reason = "Signature doesn't match"
309            log_reffusal(key_ss58, reason)
310            return (False, json_error(401, "Signatures doesn't match"))
311
312        body_dict: dict[str, dict[str, Any]] = json.loads(body)
313        target_key = body_dict['params'].get("target_key", None)
314        if not target_key or target_key != module_key:
315            reason = "Wrong target_key in body"
316            log_reffusal(key_ss58, reason)
317            return (False, json_error(401, "Wrong target_key in body"))
318
319        return (True, None)
320
321    def _check_key_registered(
322        self,
323        subnets_whitelist: list[int] | None,
324        headers_dict: dict[str, str],
325        blockchain_cache: TTLDict[str, list[Ss58Address]],
326        host_key: Keypair,
327        use_testnet: bool,
328    ):
329        key = headers_dict["x-key"]
330        if not is_hex_string(key):
331            return (False, json_error(400, "X-Key should be a hex value"))
332        key = parse_hex(key)
333
334        # TODO: checking for key being registered should be smarter
335        # e.g. query and store all registered modules periodically.
336
337        ss58 = try_ss58_decode(key)
338        if ss58 is None:
339            reason = "Caller key could not be decoded into a ss58address"
340            log_reffusal(key.decode(), reason)
341            return (False, json_error(400, reason))
342
343        # If subnets whitelist is specified, checks if key is registered in one
344        # of the given subnets
345
346        allowed_subnets: dict[int, bool] = {}
347        caller_subnets: list[int] = []
348        if subnets_whitelist is not None:
349            def query_keys(subnet: int):
350                try:
351                    node_url = get_node_url(None, use_testnet=use_testnet)
352                    client = make_client(node_url)  # TODO: get client from outer context
353                    return [*client.query_map_key(subnet).values()]
354                except Exception:
355                    log(
356                        "WARNING: Could not connect to a blockchain node"
357                    )
358                    return_list: list[Ss58Address] = []
359                    return return_list
360
361            # TODO: client pool for entire module server
362
363            got_keys = False
364            no_keys_reason = (
365                "Miner could not connect to a blockchain node "
366                "or there is no key registered on the subnet(s) {} "
367            )
368            for subnet in subnets_whitelist:
369                get_keys_on_subnet = partial(query_keys, subnet)
370                cache_key = f"keys_on_subnet_{subnet}"
371                keys_on_subnet = blockchain_cache.get_or_insert_lazy(
372                    cache_key, get_keys_on_subnet
373                )
374                if len(keys_on_subnet) == 0:
375                    reason = no_keys_reason.format(subnet)
376                    log(f"WARNING: {reason}")
377                else:
378                    got_keys = True
379                if host_key.ss58_address not in keys_on_subnet:
380                    log(
381                        f"WARNING: This miner is deregistered on subnet {subnet}"
382                    )
383                else:
384                    allowed_subnets[subnet] = True
385                if ss58 in keys_on_subnet:
386                    caller_subnets.append(subnet)
387            if not got_keys:
388                return False, json_error(503, no_keys_reason.format(subnets_whitelist))
389            if not allowed_subnets:
390                log("WARNING: Miner is not registered on any subnet")
391                return False, json_error(403, "Miner is not registered on any subnet")
392
393            # searches for a common subnet between caller and miner
394            # TODO: use sets
395            allowed_subnets = {
396                subnet: allowed for subnet, allowed in allowed_subnets.items() if (
397                    subnet in caller_subnets
398                )
399            }
400            if not allowed_subnets:
401                reason = "Caller key is not registered in any subnet that the miner is"
402                log_reffusal(ss58, reason)
403                return False, json_error(
404                    403, reason
405                )
406        else:
407            # accepts everything
408            pass
409
410        return (True, None)
411
412
def build_route_class(
    verifiers: Sequence[AbstractVerifier]
) -> type[APIRoute]:
    """Create an ``APIRoute`` subclass that gates '/method' endpoints.

    For requests whose path starts with ``/method`` every verifier is
    awaited in order; the first non-``None`` response is returned instead
    of calling the endpoint. All other requests go straight to the
    original route handler.
    """

    class CheckListsRoute(APIRoute):
        def get_route_handler(self):
            wrapped_handler = super().get_route_handler()

            async def custom_route_handler(request: Request) -> Response | JSONResponse:
                if request.url.path.startswith('/method'):
                    for check in verifiers:
                        rejection = await check.verify(request)
                        if rejection is not None:
                            return rejection

                # Either not a '/method' path or every verifier approved.
                handler_response: Response = await wrapped_handler(request)
                return handler_response

            return custom_route_handler

    return CheckListsRoute
HEX_PATTERN = re.compile('^[0-9a-fA-F]+$')
def is_hex_string(string: str):
29def is_hex_string(string: str):
30    return bool(HEX_PATTERN.match(string))
def parse_hex(hex_str: str) -> bytes:
33def parse_hex(hex_str: str) -> bytes:
34    if hex_str[0:2] == "0x":
35        return bytes.fromhex(hex_str[2:])
36    else:
37        return bytes.fromhex(hex_str)
class AbstractVerifier(typing.Protocol):
40class AbstractVerifier(Protocol):
41    @abstractmethod
42    async def verify(self, request: Request) -> JSONResponse | None:
43        """Please dont mutate the request D:"""
44        ...

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
@abstractmethod
async def verify( self, request: starlette.requests.Request) -> starlette.responses.JSONResponse | None:
41    @abstractmethod
42    async def verify(self, request: Request) -> JSONResponse | None:
43        """Please dont mutate the request D:"""
44        ...

Please don't mutate the request.

class StakeLimiterVerifier(AbstractVerifier):
47class StakeLimiterVerifier(AbstractVerifier):
48    def __init__(
49        self,
50        subnets_whitelist: list[int] | None,
51        params_: StakeLimiterParams | None,
52    ):
53        self.subnets_whitelist = subnets_whitelist
54        self.params_ = params_
55        if not self.params_:
56            params = StakeLimiterParams()
57        else:
58            params = self.params_
59        self.limiter = StakeLimiter(
60            self.subnets_whitelist,
61            epoch=params.epoch,
62            max_cache_age=params.cache_age,
63            get_refill_rate=params.get_refill_per_epoch,
64        )
65
66    async def verify(self, request: Request):
67
68        if request.client is None:
69            response = JSONResponse(
70                status_code=401,
71                content={
72                    "error": "Address should be present in request"
73                }
74            )
75            return response
76
77        key = request.headers.get('x-key')
78        if not key:
79            response = JSONResponse(
80                status_code=401,
81                content={"error": "Valid X-Key not provided on headers"}
82            )
83            return response
84
85        is_allowed = await self.limiter.allow(key)
86
87        if not is_allowed:
88            response = JSONResponse(
89                status_code=429,
90                headers={"X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds"},
91                content={"error": "Rate limit exceeded"}
92            )
93            return response
94        return None

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
StakeLimiterVerifier( subnets_whitelist: list[int] | None, params_: communex.module._rate_limiters.limiters.StakeLimiterParams | None)
48    def __init__(
49        self,
50        subnets_whitelist: list[int] | None,
51        params_: StakeLimiterParams | None,
52    ):
53        self.subnets_whitelist = subnets_whitelist
54        self.params_ = params_
55        if not self.params_:
56            params = StakeLimiterParams()
57        else:
58            params = self.params_
59        self.limiter = StakeLimiter(
60            self.subnets_whitelist,
61            epoch=params.epoch,
62            max_cache_age=params.cache_age,
63            get_refill_rate=params.get_refill_per_epoch,
64        )
subnets_whitelist
params_
limiter
async def verify(self, request: starlette.requests.Request):
66    async def verify(self, request: Request):
67
68        if request.client is None:
69            response = JSONResponse(
70                status_code=401,
71                content={
72                    "error": "Address should be present in request"
73                }
74            )
75            return response
76
77        key = request.headers.get('x-key')
78        if not key:
79            response = JSONResponse(
80                status_code=401,
81                content={"error": "Valid X-Key not provided on headers"}
82            )
83            return response
84
85        is_allowed = await self.limiter.allow(key)
86
87        if not is_allowed:
88            response = JSONResponse(
89                status_code=429,
90                headers={"X-RateLimit-TryAfter": f"{str(await self.limiter.retry_after(key))} seconds"},
91                content={"error": "Rate limit exceeded"}
92            )
93            return response
94        return None

Please don't mutate the request.

class ListVerifier(AbstractVerifier):
 97class ListVerifier(AbstractVerifier):
 98    def __init__(
 99            self,
100            blacklist: list[Ss58Address] | None,
101            whitelist: list[Ss58Address] | None,
102            ip_blacklist: list[str] | None,
103    ):
104        self.blacklist = blacklist
105        self.whitelist = whitelist
106        self.ip_blacklist = ip_blacklist
107
108    async def verify(self, request: Request) -> JSONResponse | None:
109        key = request.headers.get("x-key")
110        if not key:
111            reason = "Missing header: X-Key"
112            log(f"INFO: refusing module request because: {reason}")
113            return json_error(400, "Missing header: X-Key")
114
115        ss58 = try_ss58_decode(key)
116        if ss58 is None:
117            reason = "Caller key could not be decoded into a ss58address"
118            log_reffusal(key, reason)
119            return json_error(400, reason)
120        if request.client is None:
121            return json_error(400, "Address should be present in request")
122        if self.blacklist and ss58 in self.blacklist:
123            return json_error(403, "You are blacklisted")
124        if self.ip_blacklist and request.client.host in self.ip_blacklist:
125            return json_error(403, "Your IP is blacklisted")
126        if self.whitelist and ss58 not in self.whitelist:
127            return json_error(403, "You are not whitelisted")
128        return None

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
ListVerifier( blacklist: list[communex.types.Ss58Address] | None, whitelist: list[communex.types.Ss58Address] | None, ip_blacklist: list[str] | None)
 98    def __init__(
 99            self,
100            blacklist: list[Ss58Address] | None,
101            whitelist: list[Ss58Address] | None,
102            ip_blacklist: list[str] | None,
103    ):
104        self.blacklist = blacklist
105        self.whitelist = whitelist
106        self.ip_blacklist = ip_blacklist
blacklist
whitelist
ip_blacklist
async def verify( self, request: starlette.requests.Request) -> starlette.responses.JSONResponse | None:
108    async def verify(self, request: Request) -> JSONResponse | None:
109        key = request.headers.get("x-key")
110        if not key:
111            reason = "Missing header: X-Key"
112            log(f"INFO: refusing module request because: {reason}")
113            return json_error(400, "Missing header: X-Key")
114
115        ss58 = try_ss58_decode(key)
116        if ss58 is None:
117            reason = "Caller key could not be decoded into a ss58address"
118            log_reffusal(key, reason)
119            return json_error(400, reason)
120        if request.client is None:
121            return json_error(400, "Address should be present in request")
122        if self.blacklist and ss58 in self.blacklist:
123            return json_error(403, "You are blacklisted")
124        if self.ip_blacklist and request.client.host in self.ip_blacklist:
125            return json_error(403, "Your IP is blacklisted")
126        if self.whitelist and ss58 not in self.whitelist:
127            return json_error(403, "You are not whitelisted")
128        return None

Please don't mutate the request.

class IpLimiterVerifier(AbstractVerifier):
131class IpLimiterVerifier(AbstractVerifier):
132    def __init__(
133            self,
134            params: IpLimiterParams | None,
135    ):
136        '''
137        :param limiter: KeyLimiter instance OR None
138
139        If limiter is None, then a default TokenBucketLimiter is used with the following config:
140        bucket_size=200, refill_rate=15
141        '''
142
143        # fallback to default limiter
144        if not params:
145            params = IpLimiterParams()
146        self._limiter = TokenBucketLimiter(
147            bucket_size=params.bucket_size,
148            refill_rate=params.refill_rate
149        )
150
151    async def verify(self, request: Request):
152        assert request.client is not None, "request is invalid"
153        assert request.client.host, "request is invalid."
154
155        ip = request.client.host
156
157        is_allowed = self._limiter.allow(ip)
158
159        if not is_allowed:
160            response = JSONResponse(
161                status_code=429,
162                headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
163                content={"error": "Rate limit exceeded"}
164            )
165            return response
166        return None

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
IpLimiterVerifier( params: communex.module._rate_limiters.limiters.IpLimiterParams | None)
132    def __init__(
133            self,
134            params: IpLimiterParams | None,
135    ):
136        '''
137        :param limiter: KeyLimiter instance OR None
138
139        If limiter is None, then a default TokenBucketLimiter is used with the following config:
140        bucket_size=200, refill_rate=15
141        '''
142
143        # fallback to default limiter
144        if not params:
145            params = IpLimiterParams()
146        self._limiter = TokenBucketLimiter(
147            bucket_size=params.bucket_size,
148            refill_rate=params.refill_rate
149        )
Parameters
  • params: IpLimiterParams instance OR None

If params is None, then a TokenBucketLimiter is built from the default IpLimiterParams configuration (documented as bucket_size=200, refill_rate=15).

async def verify(self, request: starlette.requests.Request):
151    async def verify(self, request: Request):
152        assert request.client is not None, "request is invalid"
153        assert request.client.host, "request is invalid."
154
155        ip = request.client.host
156
157        is_allowed = self._limiter.allow(ip)
158
159        if not is_allowed:
160            response = JSONResponse(
161                status_code=429,
162                headers={"X-RateLimit-Remaining": str(self._limiter.remaining(ip))},
163                content={"error": "Rate limit exceeded"}
164            )
165            return response
166        return None

Please don't mutate the request.

class InputHandlerVerifier(AbstractVerifier):
169class InputHandlerVerifier(AbstractVerifier):
170    def __init__(
171        self,
172        subnets_whitelist: list[int] | None,
173        module_key: Ss58Address,
174        request_staleness: int,
175        blockchain_cache: TTLDict[str, list[Ss58Address]],
176        host_key: Keypair,
177        use_testnet: bool,
178    ):
179        self.subnets_whitelist = subnets_whitelist
180        self.module_key = module_key
181        self.request_staleness = request_staleness
182        self.blockchain_cache = blockchain_cache
183        self.host_key = host_key
184        self.use_testnet = use_testnet
185
186    async def verify(self, request: Request):
187        body = await request.body()
188
189        # TODO: we'll replace this by a Result ADT :)
190        match self._check_inputs(request, body, self.module_key):
191            case (False, error):
192                return error
193            case (True, _):
194                pass
195
196        body_dict: dict[str, dict[str, Any]] = json.loads(body)
197        timestamp = body_dict['params'].get("timestamp", None)
198        legacy_timestamp = request.headers.get("X-Timestamp", None)
199        try:
200            timestamp_to_use = timestamp if not legacy_timestamp else legacy_timestamp
201            request_time = datetime.fromisoformat(timestamp_to_use)
202        except Exception:
203            return JSONResponse(status_code=400, content={"error": "Invalid ISO timestamp given"})
204        if (datetime.now(timezone.utc) - request_time).total_seconds() > self.request_staleness:
205            return JSONResponse(status_code=400, content={"error": "Request is too stale"})
206        return None
207
208    def _check_inputs(
209        self, request: Request,
210        body: bytes,
211        module_key: Ss58Address
212    ):
213        required_headers = ["x-signature", "x-key", "x-crypto"]
214        optional_headers = ["x-timestamp"]
215
216        # TODO: we'll replace this by a Result ADT :)
217        match self._get_headers_dict(request.headers, required_headers, optional_headers):
218            case (False, error):
219                return (False, error)
220            case (True, headers_dict):
221                pass
222
223        # TODO: we'll replace this by a Result ADT :)
224        match self._check_signature(headers_dict, body, module_key):
225            case (False, error):
226                return (False, error)
227            case (True, _):
228                pass
229
230        # TODO: we'll replace this by a Result ADT :)
231        match self._check_key_registered(
232            self.subnets_whitelist,
233            headers_dict,
234            self.blockchain_cache,
235            self.host_key,
236            self.use_testnet,
237        ):
238            case (False, error):
239                return (False, error)
240            case (True, _):
241                pass
242
243        return (True, None)
244
245    def _get_headers_dict(
246        self,
247        headers: starlette.datastructures.Headers,
248        required: list[str],
249        optional: list[str],
250    ):
251        headers_dict: dict[str, str] = {}
252        for required_header in required:
253            value = headers.get(required_header)
254            if not value:
255                code = 400
256                return False, json_error(code, f"Missing header: {required_header}")
257            headers_dict[required_header] = value
258        for optional_header in optional:
259            value = headers.get(optional_header)
260            if value:
261                headers_dict[optional_header] = value
262
263        return True, headers_dict
264
265    def _check_signature(
266        self,
267        headers_dict: dict[str, str],
268        body: bytes,
269        module_key: Ss58Address
270    ):
271        key = headers_dict["x-key"]
272        signature = headers_dict["x-signature"]
273        crypto = int(headers_dict["x-crypto"])  # TODO: better handling of this
274
275        if not is_hex_string(key):
276            reason = "X-Key should be a hex value"
277            log_reffusal(key, reason)
278            return (False, json_error(400, reason))
279        try:
280            signature = parse_hex(signature)
281        except Exception:
282            reason = "Signature sent is not a valid hex value"
283            log_reffusal(key, reason)
284            return False, json_error(400, reason)
285        try:
286            key = parse_hex(key)
287        except Exception:
288            reason = "Key sent is not a valid hex value"
289            log_reffusal(key, reason)
290            return False, json_error(400, reason)
291        # decodes the key for better logging
292        key_ss58 = try_ss58_decode(key)
293        if key_ss58 is None:
294            reason = "Caller key could not be decoded into a ss58address"
295            log_reffusal(key.decode(), reason)
296            return (False, json_error(400, reason))
297
298        timestamp = headers_dict.get("x-timestamp")
299        legacy_verified = False
300        if timestamp:
301            # tries to do a legacy verification
302            json_body = json.loads(body)
303            json_body["timestamp"] = timestamp
304            stamped_body = json.dumps(json_body).encode()
305            legacy_verified = signer.verify(key, crypto, stamped_body, signature)
306
307        verified = signer.verify(key, crypto, body, signature)
308        if not verified and not legacy_verified:
309            reason = "Signature doesn't match"
310            log_reffusal(key_ss58, reason)
311            return (False, json_error(401, "Signatures doesn't match"))
312
313        body_dict: dict[str, dict[str, Any]] = json.loads(body)
314        target_key = body_dict['params'].get("target_key", None)
315        if not target_key or target_key != module_key:
316            reason = "Wrong target_key in body"
317            log_reffusal(key_ss58, reason)
318            return (False, json_error(401, "Wrong target_key in body"))
319
320        return (True, None)
321
322    def _check_key_registered(
323        self,
324        subnets_whitelist: list[int] | None,
325        headers_dict: dict[str, str],
326        blockchain_cache: TTLDict[str, list[Ss58Address]],
327        host_key: Keypair,
328        use_testnet: bool,
329    ):
330        key = headers_dict["x-key"]
331        if not is_hex_string(key):
332            return (False, json_error(400, "X-Key should be a hex value"))
333        key = parse_hex(key)
334
335        # TODO: checking for key being registered should be smarter
336        # e.g. query and store all registered modules periodically.
337
338        ss58 = try_ss58_decode(key)
339        if ss58 is None:
340            reason = "Caller key could not be decoded into a ss58address"
341            log_reffusal(key.decode(), reason)
342            return (False, json_error(400, reason))
343
344        # If subnets whitelist is specified, checks if key is registered in one
345        # of the given subnets
346
347        allowed_subnets: dict[int, bool] = {}
348        caller_subnets: list[int] = []
349        if subnets_whitelist is not None:
350            def query_keys(subnet: int):
351                try:
352                    node_url = get_node_url(None, use_testnet=use_testnet)
353                    client = make_client(node_url)  # TODO: get client from outer context
354                    return [*client.query_map_key(subnet).values()]
355                except Exception:
356                    log(
357                        "WARNING: Could not connect to a blockchain node"
358                    )
359                    return_list: list[Ss58Address] = []
360                    return return_list
361
362            # TODO: client pool for entire module server
363
364            got_keys = False
365            no_keys_reason = (
366                "Miner could not connect to a blockchain node "
367                "or there is no key registered on the subnet(s) {} "
368            )
369            for subnet in subnets_whitelist:
370                get_keys_on_subnet = partial(query_keys, subnet)
371                cache_key = f"keys_on_subnet_{subnet}"
372                keys_on_subnet = blockchain_cache.get_or_insert_lazy(
373                    cache_key, get_keys_on_subnet
374                )
375                if len(keys_on_subnet) == 0:
376                    reason = no_keys_reason.format(subnet)
377                    log(f"WARNING: {reason}")
378                else:
379                    got_keys = True
380                if host_key.ss58_address not in keys_on_subnet:
381                    log(
382                        f"WARNING: This miner is deregistered on subnet {subnet}"
383                    )
384                else:
385                    allowed_subnets[subnet] = True
386                if ss58 in keys_on_subnet:
387                    caller_subnets.append(subnet)
388            if not got_keys:
389                return False, json_error(503, no_keys_reason.format(subnets_whitelist))
390            if not allowed_subnets:
391                log("WARNING: Miner is not registered on any subnet")
392                return False, json_error(403, "Miner is not registered on any subnet")
393
394            # searches for a common subnet between caller and miner
395            # TODO: use sets
396            allowed_subnets = {
397                subnet: allowed for subnet, allowed in allowed_subnets.items() if (
398                    subnet in caller_subnets
399                )
400            }
401            if not allowed_subnets:
402                reason = "Caller key is not registered in any subnet that the miner is"
403                log_reffusal(ss58, reason)
404                return False, json_error(
405                    403, reason
406                )
407        else:
408            # accepts everything
409            pass
410
411        return (True, None)

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
InputHandlerVerifier( subnets_whitelist: list[int] | None, module_key: communex.types.Ss58Address, request_staleness: int, blockchain_cache: communex.util.memo.TTLDict[str, list[communex.types.Ss58Address]], host_key: substrateinterface.keypair.Keypair, use_testnet: bool)
170    def __init__(
171        self,
172        subnets_whitelist: list[int] | None,
173        module_key: Ss58Address,
174        request_staleness: int,
175        blockchain_cache: TTLDict[str, list[Ss58Address]],
176        host_key: Keypair,
177        use_testnet: bool,
178    ):
           # Only stores the verifier configuration; nothing is validated or queried here.
179        self.subnets_whitelist = subnets_whitelist  # None disables the subnet registration check
180        self.module_key = module_key  # compared against params.target_key of each request body
181        self.request_staleness = request_staleness  # maximum accepted request age, in seconds
182        self.blockchain_cache = blockchain_cache  # caches the keys registered on each subnet
183        self.host_key = host_key  # this miner's keypair; its ss58_address must be registered too
184        self.use_testnet = use_testnet  # forwarded to get_node_url when querying subnet keys
subnets_whitelist
module_key
request_staleness
blockchain_cache
host_key
use_testnet
async def verify(self, request: starlette.requests.Request):
186    async def verify(self, request: Request):
187        body = await request.body()
188
189        # TODO: we'll replace this by a Result ADT :)
190        match self._check_inputs(request, body, self.module_key):
191            case (False, error):
192                return error
193            case (True, _):
194                pass
195
196        body_dict: dict[str, dict[str, Any]] = json.loads(body)
197        timestamp = body_dict['params'].get("timestamp", None)
198        legacy_timestamp = request.headers.get("X-Timestamp", None)
199        try:
200            timestamp_to_use = timestamp if not legacy_timestamp else legacy_timestamp
201            request_time = datetime.fromisoformat(timestamp_to_use)
202        except Exception:
203            return JSONResponse(status_code=400, content={"error": "Invalid ISO timestamp given"})
204        if (datetime.now(timezone.utc) - request_time).total_seconds() > self.request_staleness:
205            return JSONResponse(status_code=400, content={"error": "Request is too stale"})
206        return None

Please don't mutate the request.

def build_route_class( verifiers: Sequence[AbstractVerifier]) -> type[fastapi.routing.APIRoute]:
def build_route_class(
    verifiers: Sequence[AbstractVerifier]
) -> type[APIRoute]:
    """
    Create an APIRoute subclass that gates ``/method`` requests.

    For requests whose path starts with ``/method`` the verifiers are
    awaited in order; the first non-None response is returned and the
    endpoint is not called. All other requests, and requests every verifier
    accepts, go to the regular route handler.
    """

    class CheckListsRoute(APIRoute):
        def get_route_handler(self):
            wrapped_handler = super().get_route_handler()

            async def custom_route_handler(request: Request) -> Response | JSONResponse:
                if request.url.path.startswith('/method'):
                    for verifier in verifiers:
                        rejection = await verifier.verify(request)
                        if rejection is not None:
                            return rejection

                # either not a /method path, or every verifier accepted
                handler_response: Response = await wrapped_handler(request)
                return handler_response

            return custom_route_handler

    return CheckListsRoute