Edit on GitHub

communex.cli.module

  1import importlib.util
  2from typing import Any, Optional, cast
  3
  4import typer
  5import uvicorn
  6from typer import Context
  7
  8import communex.balance as c_balance
  9from communex._common import intersection_update
 10from communex.cli._common import (
 11    make_custom_context,
 12    print_module_info,
 13    print_table_from_plain_dict,
 14)
 15from communex.errors import ChainTransactionError
 16from communex.key import check_ss58_address
 17from communex.misc import get_map_modules
 18from communex.module._rate_limiters.limiters import (
 19    IpLimiterParams,
 20    StakeLimiterParams,
 21)
 22from communex.module.server import ModuleServer
 23from communex.types import Ss58Address
 24from communex.util import is_ip_valid
 25
 26module_app = typer.Typer(no_args_is_help=True)
 27
 28
 29def list_to_ss58(str_list: list[str] | None) -> list[Ss58Address] | None:
 30    """Raises AssertionError if some input is not a valid Ss58Address."""
 31
 32    if str_list is None:
 33        return None
 34    new_list: list[Ss58Address] = []
 35    for item in str_list:
 36        new_item = check_ss58_address(item)
 37        new_list.append(new_item)
 38    return new_list
 39
 40
 41# TODO: refactor module register CLI
 42# - module address should be a single (arbitrary) parameter
 43# - key can be infered from name or vice-versa?
 44@module_app.command()
 45def register(
 46    ctx: Context,
 47    name: str,
 48    key: str,
 49    netuid: int,
 50    ip: Optional[str] = None,
 51    port: Optional[int] = None,
 52    metadata: Optional[str] = None,
 53):
 54    """
 55    Registers a module on a subnet.
 56    """
 57    context = make_custom_context(ctx)
 58    client = context.com_client()
 59    if metadata and len(metadata) > 59:
 60        raise ValueError("Metadata must be less than 60 characters")
 61
 62    burn = client.get_burn(netuid=netuid)
 63
 64    if netuid != 0:
 65        do_burn = context.confirm(
 66            f"{c_balance.from_nano(burn)} $COMAI will be permanently burned. Do you want to continue?"
 67        )
 68
 69        if not do_burn:
 70            context.info("Not registering")
 71            raise typer.Abort()
 72
 73    resolved_key = context.load_key(key, None)
 74
 75    with context.progress_status(f"Registering Module {name}..."):
 76        subnet_name = client.get_subnet_name(netuid)
 77        address = f"{ip}:{port}"
 78
 79        response = client.register_module(
 80            resolved_key,
 81            name=name,
 82            subnet=subnet_name,
 83            address=address,
 84            metadata=metadata,
 85        )
 86
 87        if response.is_success:
 88            context.info(f"Module {name} registered")
 89        else:
 90            raise ChainTransactionError(response.error_message)  # type: ignore
 91
 92
 93@module_app.command()
 94def deregister(ctx: Context, key: str, netuid: int):
 95    """
 96    Deregisters a module from a subnet.
 97    """
 98    context = make_custom_context(ctx)
 99    client = context.com_client()
100
101    resolved_key = context.load_key(key, None)
102
103    with context.progress_status(
104        f"Deregistering your module on subnet {netuid}..."
105    ):
106        response = client.deregister_module(key=resolved_key, netuid=netuid)
107
108        if response.is_success:
109            context.info("Module deregistered")
110        else:
111            raise ChainTransactionError(response.error_message)  # type: ignore
112
113
@module_app.command()
def update(
    ctx: Context,
    key: str,
    netuid: int,
    name: Optional[str] = None,
    ip: Optional[str] = None,
    port: Optional[int] = None,
    delegation_fee: Optional[int] = None,
    metadata: Optional[str] = None,
):
    """
    Update module with custom parameters.
    """
    # Developer notes (kept out of the docstring, which Typer prints as help):
    # the module is looked up on `netuid` by the SS58 address of `key`; every
    # option left as None keeps its current on-chain value. Raises ValueError
    # for bad metadata/ip or when no module is registered with this key, and
    # ChainTransactionError when the update extrinsic fails.
    context = make_custom_context(ctx)
    client = context.com_client()

    if metadata and len(metadata) > 59:
        raise ValueError("Metadata must be less than 60 characters")

    resolved_key = context.load_key(key, None)

    if ip and not is_ip_valid(ip):
        raise ValueError("Invalid ip address")
    modules = get_map_modules(client, netuid=netuid, include_balances=False)

    # Modules are matched by owner key, not by name.
    module = next(
        (
            item
            for item in modules.values()
            if item["key"] == resolved_key.ss58_address
        ),
        None,
    )

    if module is None:
        # Report what was searched for: `name` is the *new* name (often None).
        raise ValueError(
            f"No module registered with key {resolved_key.ss58_address} "
            f"on subnet {netuid}"
        )
    module_params = {
        "name": name,
        "ip": ip,
        "port": port,
        "delegation_fee": delegation_fee,
        "metadata": metadata,
    }
    # Only options the user actually passed override the current values.
    to_update = {
        param: value
        for param, value in module_params.items()
        if value is not None
    }

    # Split on the *last* colon only, so an address containing extra colons
    # does not break the two-way unpacking.
    current_address = module["address"]
    if ":" in current_address:
        current_ip, current_port = current_address.rsplit(":", 1)
    else:
        current_ip, current_port = current_address, None

    new_ip = to_update.get("ip", current_ip)
    new_port = to_update.get("port", current_port)

    address = f"{new_ip}:{new_port}" if new_port is not None else new_ip
    to_update["address"] = address
    # NOTE(review): intersection_update presumably only overwrites keys that
    # already exist in the module dict (dropping the helper "ip"/"port"
    # entries) — confirm against communex._common.
    updated_module = intersection_update(dict(module), to_update)
    module.update(updated_module)  # type: ignore
    with context.progress_status(
        f"Updating Module on a subnet with netuid '{netuid}' ..."
    ):
        response = client.update_module(
            key=resolved_key,
            name=module["name"],
            address=module["address"],
            delegation_fee=module["delegation_fee"],
            netuid=netuid,
            metadata=module["metadata"],
        )

    if response.is_success:
        context.info(f"Module {key} updated")
    else:
        raise ChainTransactionError(response.error_message)  # type: ignore
195
196
197@module_app.command()
198def serve(
199    ctx: typer.Context,
200    class_path: str,
201    key: str,
202    port: int = 8000,
203    ip: Optional[str] = None,
204    subnets_whitelist: Optional[list[int]] = [0],
205    whitelist: Optional[list[str]] = None,
206    blacklist: Optional[list[str]] = None,
207    ip_blacklist: Optional[list[str]] = None,
208    test_mode: Optional[bool] = False,
209    request_staleness: int = typer.Option(120),
210    use_ip_limiter: Optional[bool] = typer.Option(
211        False, help=("If this value is passed, the ip limiter will be used")
212    ),
213    token_refill_rate_base_multiplier: Optional[int] = typer.Option(
214        None,
215        help=(
216            "Multiply the base limit per stake. Only used in stake limiter mode."
217        ),
218    ),
219):
220    """
221    Serves a module on `127.0.0.1` on port `port`. `class_path` should specify
222    the dotted path to the module class e.g. `module.submodule.ClassName`.
223    """
224    context = make_custom_context(ctx)
225    use_testnet = context.get_use_testnet()
226    path_parts = class_path.split(".")
227    match path_parts:
228        case [*module_parts, class_name]:
229            module_path = ".".join(module_parts)
230            if not module_path:
231                # This could do some kind of relative import somehow?
232                raise ValueError(
233                    f"Invalid class path: `{class_path}`, module name is missing"
234                )
235            if not class_name:
236                raise ValueError(
237                    f"Invalid class path: `{class_path}`, class name is missing"
238                )
239        case _:
240            # This is impossible
241            raise Exception(f"Invalid class path: `{class_path}`")
242
243    try:
244        module = importlib.import_module(module_path)
245    except ModuleNotFoundError:
246        context.error(f"Module `{module_path}` not found")
247        raise typer.Exit(code=1)
248
249    try:
250        class_obj = getattr(module, class_name)
251    except AttributeError:
252        context.error(f"Class `{class_name}` not found in module `{module}`")
253        raise typer.Exit(code=1)
254
255    keypair = context.load_key(key, None)
256
257    if test_mode:
258        subnets_whitelist = None
259    token_refill_rate = token_refill_rate_base_multiplier or 1
260    limiter_params = (
261        IpLimiterParams()
262        if use_ip_limiter
263        else StakeLimiterParams(token_ratio=token_refill_rate)
264    )
265
266    if whitelist is None:
267        context.info(
268            "WARNING: No whitelist provided, will accept calls from any key"
269        )
270
271    try:
272        whitelist_ss58 = list_to_ss58(whitelist)
273    except AssertionError:
274        context.error("Invalid SS58 address passed to whitelist")
275        exit(1)
276    try:
277        blacklist_ss58 = list_to_ss58(blacklist)
278    except AssertionError:
279        context.error("Invalid SS58 address passed on blacklist")
280        exit(1)
281    cast(list[Ss58Address] | None, whitelist)
282
283    server = ModuleServer(
284        class_obj(),
285        keypair,
286        whitelist=whitelist_ss58,
287        blacklist=blacklist_ss58,
288        subnets_whitelist=subnets_whitelist,
289        max_request_staleness=request_staleness,
290        limiter=limiter_params,
291        ip_blacklist=ip_blacklist,
292        use_testnet=use_testnet,
293    )
294    app = server.get_fastapi_app()
295    host = ip or "127.0.0.1"
296    uvicorn.run(app, host=host, port=port)  # type: ignore
297
298
@module_app.command()
def info(ctx: Context, name: str, balance: bool = False, netuid: int = 0):
    """
    Gets module info
    """
    # Fetch the modules of subnet `netuid` (with balances if requested), take
    # the first one whose name equals `name`, and print it as a table.
    cli_ctx = make_custom_context(ctx)
    com_client = cli_ctx.com_client()

    status_msg = f"Getting Module {name} on a subnet with netuid {netuid}…"
    with cli_ctx.progress_status(status_msg):
        module_map = get_map_modules(
            com_client, netuid=netuid, include_balances=balance
        )
        found = None
        for candidate in module_map.values():
            if candidate["name"] == name:
                found = candidate
                break

    if found is None:
        raise ValueError("Module not found")

    print_table_from_plain_dict(
        cast(dict[str, Any], found), ["Params", "Values"], cli_ctx.console
    )
326
327
@module_app.command(name="list")
def inventory(ctx: Context, balances: bool = False, netuid: int = 0):
    """
    Modules stats on the network.
    """
    cli_ctx = make_custom_context(ctx)
    com_client = cli_ctx.com_client()

    module_map = cast(
        dict[str, Any],
        get_map_modules(com_client, netuid=netuid, include_balances=balances),
    )

    miners: list[Any] = []
    validators: list[Any] = []
    inactive: list[Any] = []

    # Bucket each module: no incentive and no dividends -> inactive;
    # more incentive than dividends -> miner; anything else -> validator.
    for entry in module_map.values():
        incentive, dividends = entry["incentive"], entry["dividends"]
        if incentive == dividends == 0:
            bucket = inactive
        elif incentive > dividends:
            bucket = miners
        else:
            bucket = validators
        bucket.append(entry)

    for group, label in (
        (miners, "miners"),
        (validators, "validators"),
        (inactive, "inactive"),
    ):
        print_module_info(com_client, group, cli_ctx.console, netuid, label)
module_app = <typer.main.Typer object>
def list_to_ss58(str_list: list[str] | None) -> list[communex.types.Ss58Address] | None:
def list_to_ss58(str_list: list[str] | None) -> list[Ss58Address] | None:
    """Validate each string of ``str_list`` as an SS58 address.

    Args:
        str_list: Candidate addresses, or None.

    Returns:
        None when ``str_list`` is None; otherwise a new list with the result
        of ``check_ss58_address`` for every entry, in the original order.

    Raises:
        AssertionError: if some input is not a valid Ss58Address (raised by
            ``check_ss58_address`` on the first invalid entry).
    """
    if str_list is None:
        return None
    return [check_ss58_address(candidate) for candidate in str_list]

Raises AssertionError if some input is not a valid Ss58Address.

@module_app.command()
def register( ctx: typer.models.Context, name: str, key: str, netuid: int, ip: Optional[str] = None, port: Optional[int] = None, metadata: Optional[str] = None):
@module_app.command()
def register(
    ctx: Context,
    name: str,
    key: str,
    netuid: int,
    ip: Optional[str] = None,
    port: Optional[int] = None,
    metadata: Optional[str] = None,
):
    """
    Registers a module on a subnet.
    """
    # Developer notes (kept out of the docstring, which Typer prints as help):
    # - metadata and ip are validated *before* the burn prompt, so the user is
    #   never asked to confirm a burn for a registration that is then refused;
    # - the chain address is `ip:port`, or just `ip` when no port is given
    #   (the same shape `update` rebuilds). A missing ip/port used to be
    #   interpolated as the literal text "None" (e.g. "None:None");
    # - raises ChainTransactionError when the registration extrinsic fails.
    context = make_custom_context(ctx)
    client = context.com_client()
    if metadata and len(metadata) > 59:
        raise ValueError("Metadata must be less than 60 characters")

    if ip is None:
        raise ValueError("An ip address is required to register a module")
    if not is_ip_valid(ip):
        raise ValueError("Invalid ip address")
    address = ip if port is None else f"{ip}:{port}"

    burn = client.get_burn(netuid=netuid)

    # Only registrations on non-zero subnets ask the user to confirm the burn.
    if netuid != 0:
        do_burn = context.confirm(
            f"{c_balance.from_nano(burn)} $COMAI will be permanently burned. Do you want to continue?"
        )

        if not do_burn:
            context.info("Not registering")
            raise typer.Abort()

    resolved_key = context.load_key(key, None)

    with context.progress_status(f"Registering Module {name}..."):
        subnet_name = client.get_subnet_name(netuid)

        response = client.register_module(
            resolved_key,
            name=name,
            subnet=subnet_name,
            address=address,
            metadata=metadata,
        )

        if response.is_success:
            context.info(f"Module {name} registered")
        else:
            raise ChainTransactionError(response.error_message)  # type: ignore

Registers a module on a subnet.

@module_app.command()
def deregister(ctx: typer.models.Context, key: str, netuid: int):
 94@module_app.command()
 95def deregister(ctx: Context, key: str, netuid: int):
 96    """
 97    Deregisters a module from a subnet.
 98    """
 99    context = make_custom_context(ctx)
100    client = context.com_client()
101
102    resolved_key = context.load_key(key, None)
103
104    with context.progress_status(
105        f"Deregistering your module on subnet {netuid}..."
106    ):
107        response = client.deregister_module(key=resolved_key, netuid=netuid)
108
109        if response.is_success:
110            context.info("Module deregistered")
111        else:
112            raise ChainTransactionError(response.error_message)  # type: ignore

Deregisters a module from a subnet.

@module_app.command()
def update( ctx: typer.models.Context, key: str, netuid: int, name: Optional[str] = None, ip: Optional[str] = None, port: Optional[int] = None, delegation_fee: Optional[int] = None, metadata: Optional[str] = None):
@module_app.command()
def update(
    ctx: Context,
    key: str,
    netuid: int,
    name: Optional[str] = None,
    ip: Optional[str] = None,
    port: Optional[int] = None,
    delegation_fee: Optional[int] = None,
    metadata: Optional[str] = None,
):
    """
    Update module with custom parameters.
    """
    # Developer notes (kept out of the docstring, which Typer prints as help):
    # the module is looked up on `netuid` by the SS58 address of `key`; every
    # option left as None keeps its current on-chain value. Raises ValueError
    # for bad metadata/ip or when no module is registered with this key, and
    # ChainTransactionError when the update extrinsic fails.
    context = make_custom_context(ctx)
    client = context.com_client()

    if metadata and len(metadata) > 59:
        raise ValueError("Metadata must be less than 60 characters")

    resolved_key = context.load_key(key, None)

    if ip and not is_ip_valid(ip):
        raise ValueError("Invalid ip address")
    modules = get_map_modules(client, netuid=netuid, include_balances=False)

    # Modules are matched by owner key, not by name.
    module = next(
        (
            item
            for item in modules.values()
            if item["key"] == resolved_key.ss58_address
        ),
        None,
    )

    if module is None:
        # Report what was searched for: `name` is the *new* name (often None).
        raise ValueError(
            f"No module registered with key {resolved_key.ss58_address} "
            f"on subnet {netuid}"
        )
    module_params = {
        "name": name,
        "ip": ip,
        "port": port,
        "delegation_fee": delegation_fee,
        "metadata": metadata,
    }
    # Only options the user actually passed override the current values.
    to_update = {
        param: value
        for param, value in module_params.items()
        if value is not None
    }

    # Split on the *last* colon only, so an address containing extra colons
    # does not break the two-way unpacking.
    current_address = module["address"]
    if ":" in current_address:
        current_ip, current_port = current_address.rsplit(":", 1)
    else:
        current_ip, current_port = current_address, None

    new_ip = to_update.get("ip", current_ip)
    new_port = to_update.get("port", current_port)

    address = f"{new_ip}:{new_port}" if new_port is not None else new_ip
    to_update["address"] = address
    # NOTE(review): intersection_update presumably only overwrites keys that
    # already exist in the module dict (dropping the helper "ip"/"port"
    # entries) — confirm against communex._common.
    updated_module = intersection_update(dict(module), to_update)
    module.update(updated_module)  # type: ignore
    with context.progress_status(
        f"Updating Module on a subnet with netuid '{netuid}' ..."
    ):
        response = client.update_module(
            key=resolved_key,
            name=module["name"],
            address=module["address"],
            delegation_fee=module["delegation_fee"],
            netuid=netuid,
            metadata=module["metadata"],
        )

    if response.is_success:
        context.info(f"Module {key} updated")
    else:
        raise ChainTransactionError(response.error_message)  # type: ignore

Update module with custom parameters.

@module_app.command()
def serve( ctx: typer.models.Context, class_path: str, key: str, port: int = 8000, ip: Optional[str] = None, subnets_whitelist: Optional[list[int]] = [0], whitelist: Optional[list[str]] = None, blacklist: Optional[list[str]] = None, ip_blacklist: Optional[list[str]] = None, test_mode: Optional[bool] = False, request_staleness: int = <typer.models.OptionInfo object>, use_ip_limiter: Optional[bool] = <typer.models.OptionInfo object>, token_refill_rate_base_multiplier: Optional[int] = <typer.models.OptionInfo object>):
198@module_app.command()
199def serve(
200    ctx: typer.Context,
201    class_path: str,
202    key: str,
203    port: int = 8000,
204    ip: Optional[str] = None,
205    subnets_whitelist: Optional[list[int]] = [0],
206    whitelist: Optional[list[str]] = None,
207    blacklist: Optional[list[str]] = None,
208    ip_blacklist: Optional[list[str]] = None,
209    test_mode: Optional[bool] = False,
210    request_staleness: int = typer.Option(120),
211    use_ip_limiter: Optional[bool] = typer.Option(
212        False, help=("If this value is passed, the ip limiter will be used")
213    ),
214    token_refill_rate_base_multiplier: Optional[int] = typer.Option(
215        None,
216        help=(
217            "Multiply the base limit per stake. Only used in stake limiter mode."
218        ),
219    ),
220):
221    """
222    Serves a module on `127.0.0.1` on port `port`. `class_path` should specify
223    the dotted path to the module class e.g. `module.submodule.ClassName`.
224    """
225    context = make_custom_context(ctx)
226    use_testnet = context.get_use_testnet()
227    path_parts = class_path.split(".")
228    match path_parts:
229        case [*module_parts, class_name]:
230            module_path = ".".join(module_parts)
231            if not module_path:
232                # This could do some kind of relative import somehow?
233                raise ValueError(
234                    f"Invalid class path: `{class_path}`, module name is missing"
235                )
236            if not class_name:
237                raise ValueError(
238                    f"Invalid class path: `{class_path}`, class name is missing"
239                )
240        case _:
241            # This is impossible
242            raise Exception(f"Invalid class path: `{class_path}`")
243
244    try:
245        module = importlib.import_module(module_path)
246    except ModuleNotFoundError:
247        context.error(f"Module `{module_path}` not found")
248        raise typer.Exit(code=1)
249
250    try:
251        class_obj = getattr(module, class_name)
252    except AttributeError:
253        context.error(f"Class `{class_name}` not found in module `{module}`")
254        raise typer.Exit(code=1)
255
256    keypair = context.load_key(key, None)
257
258    if test_mode:
259        subnets_whitelist = None
260    token_refill_rate = token_refill_rate_base_multiplier or 1
261    limiter_params = (
262        IpLimiterParams()
263        if use_ip_limiter
264        else StakeLimiterParams(token_ratio=token_refill_rate)
265    )
266
267    if whitelist is None:
268        context.info(
269            "WARNING: No whitelist provided, will accept calls from any key"
270        )
271
272    try:
273        whitelist_ss58 = list_to_ss58(whitelist)
274    except AssertionError:
275        context.error("Invalid SS58 address passed to whitelist")
276        exit(1)
277    try:
278        blacklist_ss58 = list_to_ss58(blacklist)
279    except AssertionError:
280        context.error("Invalid SS58 address passed on blacklist")
281        exit(1)
282    cast(list[Ss58Address] | None, whitelist)
283
284    server = ModuleServer(
285        class_obj(),
286        keypair,
287        whitelist=whitelist_ss58,
288        blacklist=blacklist_ss58,
289        subnets_whitelist=subnets_whitelist,
290        max_request_staleness=request_staleness,
291        limiter=limiter_params,
292        ip_blacklist=ip_blacklist,
293        use_testnet=use_testnet,
294    )
295    app = server.get_fastapi_app()
296    host = ip or "127.0.0.1"
297    uvicorn.run(app, host=host, port=port)  # type: ignore

Serves a module on `ip` (default `127.0.0.1`) on port `port`. `class_path` should specify the dotted path to the module class, e.g. `module.submodule.ClassName`.

@module_app.command()
def info( ctx: typer.models.Context, name: str, balance: bool = False, netuid: int = 0):
@module_app.command()
def info(ctx: Context, name: str, balance: bool = False, netuid: int = 0):
    """
    Gets module info
    """
    # Fetch the modules of subnet `netuid` (with balances if requested), take
    # the first one whose name equals `name`, and print it as a table.
    cli_ctx = make_custom_context(ctx)
    com_client = cli_ctx.com_client()

    status_msg = f"Getting Module {name} on a subnet with netuid {netuid}…"
    with cli_ctx.progress_status(status_msg):
        module_map = get_map_modules(
            com_client, netuid=netuid, include_balances=balance
        )
        found = None
        for candidate in module_map.values():
            if candidate["name"] == name:
                found = candidate
                break

    if found is None:
        raise ValueError("Module not found")

    print_table_from_plain_dict(
        cast(dict[str, Any], found), ["Params", "Values"], cli_ctx.console
    )

Gets module info

@module_app.command(name='list')
def inventory(ctx: typer.models.Context, balances: bool = False, netuid: int = 0):
@module_app.command(name="list")
def inventory(ctx: Context, balances: bool = False, netuid: int = 0):
    """
    Modules stats on the network.
    """
    cli_ctx = make_custom_context(ctx)
    com_client = cli_ctx.com_client()

    module_map = cast(
        dict[str, Any],
        get_map_modules(com_client, netuid=netuid, include_balances=balances),
    )

    miners: list[Any] = []
    validators: list[Any] = []
    inactive: list[Any] = []

    # Bucket each module: no incentive and no dividends -> inactive;
    # more incentive than dividends -> miner; anything else -> validator.
    for entry in module_map.values():
        incentive, dividends = entry["incentive"], entry["dividends"]
        if incentive == dividends == 0:
            bucket = inactive
        elif incentive > dividends:
            bucket = miners
        else:
            bucket = validators
        bucket.append(entry)

    for group, label in (
        (miners, "miners"),
        (validators, "validators"),
        (inactive, "inactive"),
    ):
        print_module_info(com_client, group, cli_ctx.console, netuid, label)

Modules stats on the network.