Edit on GitHub

communex.module.module

Tools for defining Commune modules.

 1"""
 2Tools for defining Commune modules.
 3"""
 4
 5import inspect
 6from dataclasses import dataclass
 7from typing import Any, Callable, Generic, ParamSpec, TypeVar, cast
 8
 9import pydantic
10from pydantic import BaseModel
11
12T = TypeVar("T")
13P = ParamSpec("P")
14
15
class EndpointParams(BaseModel):
    """Base class for the parameter models generated for endpoints.

    ``function_params_to_model`` creates a ``Params`` model with this class as
    ``__base__``, so configuration declared here is inherited by every
    generated endpoint parameter model.
    """

    # Pydantic v2 reads model configuration from ``model_config``. A nested
    # class spelled ``config`` (lowercase) is silently ignored, so the intended
    # ``extra="allow"`` never took effect (extra fields were dropped instead).
    # NOTE(review): extra keys are now retained on instances (``model_extra``
    # and ``model_dump()``); confirm callers consuming these models expect that.
    model_config = pydantic.ConfigDict(extra="allow")

    # Kept only for backward compatibility with code that may reference
    # ``EndpointParams.config``; pydantic does not read this class.
    class config:
        extra = "allow"
20
@dataclass
class EndpointDefinition(Generic[T, P]):
    """Metadata that the ``endpoint`` decorator stores on a function.

    The generated ``__init__`` takes ``name``, ``fn`` and ``params_model`` in
    that order (positionally or by keyword).
    """

    # Endpoint name; ``endpoint`` fills it with the function's ``__name__``.
    name: str
    # The decorated callable itself.
    fn: Callable[P, T]
    # Pydantic model describing ``fn``'s parameters.
    params_model: type[EndpointParams]
26
27
def endpoint(fn: Callable[P, T]) -> Callable[P, T]:
    """Mark ``fn`` as a module endpoint and return it unchanged.

    The function's signature is converted into a pydantic parameter model by
    ``function_params_to_model``. The resulting ``EndpointDefinition`` is
    stored on the function as ``_endpoint_def``, which is the attribute
    ``Module.extract_endpoints`` looks for.

    Raises:
        Exception: propagated from ``function_params_to_model`` when a
            parameter has no type annotation.
    """
    # Keyword arguments are evaluated left to right: build the parameter
    # model first, then read the name (same order as computing them one by one).
    definition = EndpointDefinition(
        params_model=function_params_to_model(inspect.signature(fn)),
        name=fn.__name__,
        fn=fn,
    )
    # Attach the metadata to the function object; callers still receive
    # the very same function.
    fn._endpoint_def = definition  # type: ignore
    return fn
37
38
def function_params_to_model(
    signature: inspect.Signature,
) -> type[EndpointParams]:
    """Build a pydantic parameter model from a function signature.

    Every parameter except a leading ``self`` becomes a field of a model named
    ``Params`` whose base is ``EndpointParams``. Parameters without a default
    become required fields; parameters with a default keep that default.

    Args:
        signature: The endpoint function's signature (``endpoint`` passes
            ``inspect.signature(fn)``).

    Returns:
        The generated ``EndpointParams`` subclass.

    Raises:
        AssertionError: if a parameter named ``self`` is not the first one
            (only checked when assertions are enabled, i.e. not under ``-O``).
        Exception: if a parameter has no type annotation.
    """
    fields: dict[str, tuple[type] | tuple[type, Any]] = {}
    for i, param in enumerate(signature.parameters.values()):
        name = param.name
        if name == "self":  # cursed
            # Methods are decorated before they are bound, so their signature
            # still includes ``self``; skip it, but only in first position.
            assert i == 0
            continue
        # NOTE(review): *args / **kwargs are not special-cased; they become
        # ordinary required fields named after the parameter.

        # Test the ``Parameter.empty`` sentinel by identity. ``==`` would call
        # the other operand's ``__eq__``, which for array-like defaults returns
        # a non-bool and makes the ``if`` raise.
        annotation = param.annotation
        if annotation is param.empty:
            raise Exception(
                f"Error: annotation for parameter `{name}` not found"
            )

        if param.default is param.empty:
            # ``...`` tells pydantic the field is required.
            fields[name] = (annotation, ...)
        else:
            fields[name] = (annotation, param.default)

    model: type[EndpointParams] = cast(
        type[EndpointParams],
        pydantic.create_model(  #  type: ignore
            "Params",
            **fields,  #  type: ignore
            __base__=EndpointParams,  #  type: ignore
        ),
    )

    return model
69
70
class Module:
    """Base class for defining a module whose methods are exposed as endpoints.

    When an instance is created, every bound method that carries an
    ``_endpoint_def`` attribute (set by the ``endpoint`` decorator) is
    collected into a mapping from method name to its ``EndpointDefinition``.
    """

    def __init__(self) -> None:
        # TODO: is it possible to get this at class creation instead of object instantiation?
        self.__endpoints = self.extract_endpoints()

    def get_endpoints(self):
        """Return the endpoint mapping collected in ``__init__``."""
        return self.__endpoints

    def extract_endpoints(self):
        """Collect endpoint definitions from this instance's bound methods.

        Returns:
            A dict mapping each decorated method's attribute name to the
            ``EndpointDefinition`` stored on it, in ``inspect.getmembers``
            (alphabetical) order.
        """
        # NOTE: ``inspect.getmembers`` calls ``getattr`` on every attribute, so
        # property getters defined on subclasses run during this scan.
        bound_methods = inspect.getmembers(self, predicate=inspect.ismethod)
        return {
            attr_name: bound._endpoint_def  # type: ignore
            for attr_name, bound in bound_methods
            if hasattr(bound, "_endpoint_def")
        }
P = ~P
class EndpointParams(pydantic.main.BaseModel):
class EndpointParams(BaseModel):
    """Base class for the parameter models generated for endpoints.

    ``function_params_to_model`` creates a ``Params`` model with this class as
    ``__base__``, so configuration declared here is inherited by every
    generated endpoint parameter model.
    """

    # Pydantic v2 reads model configuration from ``model_config``. A nested
    # class spelled ``config`` (lowercase) is silently ignored, so the intended
    # ``extra="allow"`` never took effect (extra fields were dropped instead).
    # NOTE(review): extra keys are now retained on instances (``model_extra``
    # and ``model_dump()``); confirm callers consuming these models expect that.
    model_config = pydantic.ConfigDict(extra="allow")

    # Kept only for backward compatibility with code that may reference
    # ``EndpointParams.config``; pydantic does not read this class.
    class config:
        extra = "allow"

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ signature (`inspect.Signature`) of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a `RootModel` (`pydantic.root_model.RootModel`).
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding `FieldInfo` (`pydantic.fields.FieldInfo`) objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding `ComputedFieldInfo` (`pydantic.fields.ComputedFieldInfo`) objects.
  • __pydantic_extra__: A dictionary containing extra values, if `extra` (`pydantic.config.ConfigDict.extra`) is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to `ConfigDict` (`pydantic.config.ConfigDict`).

Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
dict
json
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs
model_fields
model_computed_fields
class EndpointParams.config:
18    class config:
19        extra = "allow"
extra = 'allow'
@dataclass
class EndpointDefinition(typing.Generic[~T, ~P]):
@dataclass
class EndpointDefinition(Generic[T, P]):
    """Metadata that the ``endpoint`` decorator stores on a function.

    The generated ``__init__`` takes ``name``, ``fn`` and ``params_model`` in
    that order (positionally or by keyword).
    """

    # Endpoint name; ``endpoint`` fills it with the function's ``__name__``.
    name: str
    # The decorated callable itself.
    fn: Callable[P, T]
    # Pydantic model describing ``fn``'s parameters.
    params_model: type[EndpointParams]
EndpointDefinition(name: str, fn: Callable[~P, ~T], params_model: type[EndpointParams])
name: str
fn: Callable[~P, ~T]
params_model: type[EndpointParams]
def endpoint(fn: Callable[~P, ~T]) -> Callable[~P, ~T]:
def endpoint(fn: Callable[P, T]) -> Callable[P, T]:
    """Mark ``fn`` as a module endpoint and return it unchanged.

    The function's signature is converted into a pydantic parameter model by
    ``function_params_to_model``. The resulting ``EndpointDefinition`` is
    stored on the function as ``_endpoint_def``, which is the attribute
    ``Module.extract_endpoints`` looks for.

    Raises:
        Exception: propagated from ``function_params_to_model`` when a
            parameter has no type annotation.
    """
    # Keyword arguments are evaluated left to right: build the parameter
    # model first, then read the name (same order as computing them one by one).
    definition = EndpointDefinition(
        params_model=function_params_to_model(inspect.signature(fn)),
        name=fn.__name__,
        fn=fn,
    )
    # Attach the metadata to the function object; callers still receive
    # the very same function.
    fn._endpoint_def = definition  # type: ignore
    return fn
def function_params_to_model(signature: inspect.Signature) -> type[EndpointParams]:
def function_params_to_model(
    signature: inspect.Signature,
) -> type[EndpointParams]:
    """Build a pydantic parameter model from a function signature.

    Every parameter except a leading ``self`` becomes a field of a model named
    ``Params`` whose base is ``EndpointParams``. Parameters without a default
    become required fields; parameters with a default keep that default.

    Args:
        signature: The endpoint function's signature (``endpoint`` passes
            ``inspect.signature(fn)``).

    Returns:
        The generated ``EndpointParams`` subclass.

    Raises:
        AssertionError: if a parameter named ``self`` is not the first one
            (only checked when assertions are enabled, i.e. not under ``-O``).
        Exception: if a parameter has no type annotation.
    """
    fields: dict[str, tuple[type] | tuple[type, Any]] = {}
    for i, param in enumerate(signature.parameters.values()):
        name = param.name
        if name == "self":  # cursed
            # Methods are decorated before they are bound, so their signature
            # still includes ``self``; skip it, but only in first position.
            assert i == 0
            continue
        # NOTE(review): *args / **kwargs are not special-cased; they become
        # ordinary required fields named after the parameter.

        # Test the ``Parameter.empty`` sentinel by identity. ``==`` would call
        # the other operand's ``__eq__``, which for array-like defaults returns
        # a non-bool and makes the ``if`` raise.
        annotation = param.annotation
        if annotation is param.empty:
            raise Exception(
                f"Error: annotation for parameter `{name}` not found"
            )

        if param.default is param.empty:
            # ``...`` tells pydantic the field is required.
            fields[name] = (annotation, ...)
        else:
            fields[name] = (annotation, param.default)

    model: type[EndpointParams] = cast(
        type[EndpointParams],
        pydantic.create_model(  #  type: ignore
            "Params",
            **fields,  #  type: ignore
            __base__=EndpointParams,  #  type: ignore
        ),
    )

    return model
class Module:
class Module:
    """Base class for defining a module whose methods are exposed as endpoints.

    When an instance is created, every bound method that carries an
    ``_endpoint_def`` attribute (set by the ``endpoint`` decorator) is
    collected into a mapping from method name to its ``EndpointDefinition``.
    """

    def __init__(self) -> None:
        # TODO: is it possible to get this at class creation instead of object instantiation?
        self.__endpoints = self.extract_endpoints()

    def get_endpoints(self):
        """Return the endpoint mapping collected in ``__init__``."""
        return self.__endpoints

    def extract_endpoints(self):
        """Collect endpoint definitions from this instance's bound methods.

        Returns:
            A dict mapping each decorated method's attribute name to the
            ``EndpointDefinition`` stored on it, in ``inspect.getmembers``
            (alphabetical) order.
        """
        # NOTE: ``inspect.getmembers`` calls ``getattr`` on every attribute, so
        # property getters defined on subclasses run during this scan.
        bound_methods = inspect.getmembers(self, predicate=inspect.ismethod)
        return {
            attr_name: bound._endpoint_def  # type: ignore
            for attr_name, bound in bound_methods
            if hasattr(bound, "_endpoint_def")
        }
def get_endpoints(self):
77    def get_endpoints(self):
78        return self.__endpoints
def extract_endpoints(self):
80    def extract_endpoints(self):
81        endpoints: dict[str, EndpointDefinition[Any, Any]] = {}
82        for name, method in inspect.getmembers(
83            self, predicate=inspect.ismethod
84        ):
85            if hasattr(method, "_endpoint_def"):
86                endpoint_def: EndpointDefinition = method._endpoint_def  # type: ignore
87                endpoints[name] = endpoint_def  # type: ignore
88        return endpoints