Edit on GitHub

communex.encryption

  1import struct
  2
  3import cryptography.hazmat.primitives.serialization as crypt_serialization
  4from cryptography.exceptions import InvalidSignature
  5from cryptography.hazmat.primitives.asymmetric import padding, rsa
  6
  7# def int_from_hex_bytes_be(data: str) -> int:
  8#     # return int.from_bytes(bytes.fromhex(data), 'big')
  9#     return int(data, 16)
 10
 11
 12def bytes_from_hex(data: str) -> bytes:
 13    return bytes.fromhex(data)
 14
 15
 16def encrypt_weights(
 17    key: tuple[bytes, bytes],
 18    data: list[tuple[int, int]],
 19    validator_key: list[int],
 20) -> bytes:
 21    # Create RSA public key
 22    public_numbers = rsa.RSAPublicNumbers(
 23        n=int.from_bytes(key[0], "big"),
 24        e=int.from_bytes(key[1], "big"),
 25    )
 26    rsa_key = public_numbers.public_key()
 27
 28    # Encode data
 29    encoded = (
 30        (len(data)).to_bytes(4, "big")
 31        + b"".join(
 32            uid.to_bytes(2, "big") + weight.to_bytes(2, "big")
 33            for uid, weight in data
 34        )
 35        + bytes(validator_key)
 36    )
 37
 38    # Calculate max chunk size
 39    max_chunk_size = rsa_key.key_size // 8 - 11  # 11 bytes for PKCS1v15 padding
 40
 41    # Encrypt in chunks
 42    encrypted = b""
 43    for i in range(0, len(encoded), max_chunk_size):
 44        chunk = encoded[i : i + max_chunk_size]
 45        encrypted += rsa_key.encrypt(chunk, padding.PKCS1v15())
 46
 47    return encrypted
 48
 49
 50def decrypt_weights(
 51    private_key: rsa.RSAPrivateKey, encrypted: bytes
 52) -> tuple[list[tuple[int, int]], list[int]] | None:
 53    # Decrypt in chunks
 54    decrypted = b""
 55    chunk_size = private_key.key_size // 8
 56    for i in range(0, len(encrypted), chunk_size):
 57        chunk = encrypted[i : i + chunk_size]
 58        try:
 59            decrypted += private_key.decrypt(chunk, padding.PKCS1v15())
 60        except InvalidSignature:
 61            return None
 62
 63    # Read the decrypted data
 64    cursor = 0
 65
 66    def read_u32() -> int | None:
 67        nonlocal cursor
 68        if cursor + 4 > len(decrypted):
 69            return None
 70        value = struct.unpack(">I", decrypted[cursor : cursor + 4])[0]
 71        cursor += 4
 72        return value
 73
 74    def read_u16() -> int | None:
 75        nonlocal cursor
 76        if cursor + 2 > len(decrypted):
 77            return None
 78        value = struct.unpack(">H", decrypted[cursor : cursor + 2])[0]
 79        cursor += 2
 80        return value
 81
 82    length = read_u32()
 83    if length is None:
 84        return None
 85
 86    weights: list[tuple[int, int]] = []
 87    for _ in range(length):
 88        uid = read_u16()
 89        weight = read_u16()
 90        if uid is None or weight is None:
 91            return None
 92        weights.append((uid, weight))
 93
 94    key = list(decrypted[cursor:])
 95
 96    return weights, key
 97
 98
 99def _test():
100    weights = [(1, 2), (3, 4)]
101    validator_key = [11, 22, 33, 44]
102
103    pub_key_n_hex = "d740d02640e98befc21238399205e9dba5b711d237d06df06cc2af6f92c39d76292e90c34d04939e3a3e18520482b762be6ae8859f0f9f13d075b000a8892bfc0225729bc9fcd84d2c4149347231557f4678241aaf3c080a47c33aa1f4c90aee7cacf694ddeebe9abbbf231a9fdba410afbbc8f7a3e9f776a26504ce4a1982e9"
104
105    pub_key_e_hex = "010001"
106
107    rsa_key_pem = b"""
108-----BEGIN RSA PRIVATE KEY-----
109MIICXQIBAAKBgQDXQNAmQOmL78ISODmSBenbpbcR0jfQbfBswq9vksOddikukMNN
110BJOeOj4YUgSCt2K+auiFnw+fE9B1sACoiSv8AiVym8n82E0sQUk0cjFVf0Z4JBqv
111PAgKR8M6ofTJCu58rPaU3e6+mru/Ixqf26QQr7vI96Pp93aiZQTOShmC6QIDAQAB
112AoGBAIixPf2s5yLYZLPRRK34V2QGvlTw3ETeK/nFQEdoOhT6fnh1sbBtIZkvf1NO
113clLYRjqKBZMlSXRJzu2NkT11rpm1hTTuc99w0SjZDHFpj0TppXtagmJYwHBYt5Ac
114oNan6ALTlUbxEHtIj4rGghJAJBOVTq0pi8PdVgAQgq3cArUBAkEA2f9SFOmDWN7w
115PO6yHZfj7e8i65W8v4HZXV/EWv3kCZW5KZsM3OBlqqx1txIljxF146C7ZpBLLQEK
116ubVOqKqPsQJBAPzHBuczD6GziSbN9sjgj4sAxGwExp8Z747rxGVlB56ak68aqFt1
117GDuwib0NIrrDUuGlQUKIWUm6amSwu/UJbLkCQDsZS8Bdmf0y20A5mdIKBoHPrdDe
118VEA6zJnSx6G/aN3sWDleTntm3kkJ3hPWeJYzrpkaTxO8FJVLzgOQkpWJP9ECQQD1
119q0EsRlX05BZx3k7w4D7h67b6/JFFY+GNV9qiaNRE8xqBXjkt2dnZeTQExtVwChFt
120ODz6uqV8oG5yucmS1rwRAkA1KjcZDPBRZ05wlf8VZuJjWYIRbVx3PBpQJPbtW7Vg
121fvRuW5JF+WZtGddyU4751JNNNhmwbwGmsmphy7EOHHaC
122-----END RSA PRIVATE KEY-----
123"""
124
125    # encrypted_hex = "acb87f05bb9d8bd6fd91614a8cfe44bd383d5d27ddd44f58788dc01775123413157f4040dbf8be719c160df01bbc1ea01e321a929990c558c29deb89ca348ed049f04a3ad1470a914ea884114b2889a1f1dce2f42542167d85d129bba44b6f71e6bc197d048fbd0ea08d013d9279c26d675bb7fba63928fd2dc13f886879c629"
126    # encrypted_ref = bytes_from_hex(encrypted_hex)
127
128    # ======================================================================== #
129
130    private_key = crypt_serialization.load_pem_private_key(
131        rsa_key_pem, password=None
132    )
133    assert isinstance(private_key, rsa.RSAPrivateKey)
134    pub_key_n = bytes_from_hex(pub_key_n_hex)
135    pub_key_e = bytes_from_hex(pub_key_e_hex)
136
137    encrypted = encrypt_weights((pub_key_n, pub_key_e), weights, validator_key)
138
139    decrypted = decrypt_weights(private_key, encrypted)
140    assert decrypted is not None
141    (weights_dec, key_dec) = decrypted
142    print(f"weights_dec = {weights_dec}")
143    print(f"key_dec = {key_dec}")
144
145    assert weights_dec == weights
146    assert key_dec == validator_key
147
148
# Run the round-trip self-test when this module is executed as a script.
if __name__ == "__main__":
    _test()
def bytes_from_hex(data: str) -> bytes:
    """Decode a hexadecimal string into the raw bytes it represents.

    Args:
        data: Hex digits, two per output byte, as accepted by
            ``bytes.fromhex``.

    Returns:
        The decoded bytes.

    Raises:
        ValueError: If ``data`` is not valid hexadecimal.
    """
    decoded = bytes.fromhex(data)
    return decoded
def encrypt_weights(
    key: tuple[bytes, bytes],
    data: list[tuple[int, int]],
    validator_key: list[int],
) -> bytes:
    """Encrypt a weight vector and a validator key with an RSA public key.

    The plaintext is laid out big-endian as: a 4-byte entry count, then one
    ``(uid, weight)`` pair per entry encoded as two unsigned 16-bit integers,
    then the raw bytes of ``validator_key``. The plaintext is split into
    chunks small enough for PKCS#1 v1.5 padding and every chunk is encrypted
    separately; the ciphertext blocks are concatenated in order.

    Args:
        key: ``(modulus, exponent)`` of the RSA public key, each given as
            big-endian bytes.
        data: ``(uid, weight)`` pairs; every value must fit in two bytes.
        validator_key: Byte values (0-255) appended after the weights.

    Returns:
        The concatenated ciphertext blocks.

    Raises:
        OverflowError: If a uid or weight is negative or exceeds 16 bits.
        ValueError: If ``validator_key`` contains a value outside 0-255, or
            if the key is too small to carry any PKCS#1 v1.5 payload.
    """
    modulus_bytes, exponent_bytes = key
    rsa_key = rsa.RSAPublicNumbers(
        n=int.from_bytes(modulus_bytes, "big"),
        e=int.from_bytes(exponent_bytes, "big"),
    ).public_key()

    # Serialise: count header, packed (uid, weight) pairs, validator key.
    encoded = (
        len(data).to_bytes(4, "big")
        + b"".join(
            uid.to_bytes(2, "big") + weight.to_bytes(2, "big")
            for uid, weight in data
        )
        + bytes(validator_key)
    )

    # PKCS#1 v1.5 padding consumes 11 bytes of every block.
    max_chunk_size = rsa_key.key_size // 8 - 11
    if max_chunk_size <= 0:
        # A non-positive step would make range() fail obscurely (step 0) or
        # yield nothing (negative step), silently returning empty ciphertext.
        raise ValueError(
            "RSA key is too small to encrypt data with PKCS#1 v1.5 padding"
        )

    # Collect the blocks and join once instead of growing bytes with +=.
    return b"".join(
        rsa_key.encrypt(
            encoded[start : start + max_chunk_size], padding.PKCS1v15()
        )
        for start in range(0, len(encoded), max_chunk_size)
    )
def decrypt_weights(
    private_key: rsa.RSAPrivateKey, encrypted: bytes
) -> tuple[list[tuple[int, int]], list[int]] | None:
    """Decrypt and parse a payload produced by ``encrypt_weights``.

    The ciphertext is processed in blocks of ``key_size // 8`` bytes. The
    joined plaintext is expected to hold, big-endian: a 4-byte entry count,
    that many ``(uid, weight)`` pairs of unsigned 16-bit integers, and then
    the validator key as trailing bytes.

    Args:
        private_key: RSA private key matching the public key used to encrypt.
        encrypted: Concatenated ciphertext blocks.

    Returns:
        ``(weights, key)`` where ``weights`` is the list of ``(uid, weight)``
        pairs and ``key`` is the list of trailing byte values, or ``None``
        when a block cannot be decrypted or the plaintext is too short for
        the entry count it declares.
    """
    block_size = private_key.key_size // 8

    blocks: list[bytes] = []
    for start in range(0, len(encrypted), block_size):
        block = encrypted[start : start + block_size]
        try:
            blocks.append(private_key.decrypt(block, padding.PKCS1v15()))
        except (ValueError, InvalidSignature):
            # cryptography reports a failed RSA decryption as ValueError;
            # InvalidSignature is kept so the previous handling still applies.
            return None
    decrypted = b"".join(blocks)

    header_size = 4  # unsigned 32-bit entry count
    pair_size = 4  # two unsigned 16-bit integers per entry
    if len(decrypted) < header_size:
        return None
    (count,) = struct.unpack_from(">I", decrypted, 0)

    # Reject truncated payloads up front instead of failing mid-way through.
    pairs_end = header_size + count * pair_size
    if pairs_end > len(decrypted):
        return None

    weights: list[tuple[int, int]] = [
        (uid, weight)
        for uid, weight in struct.iter_unpack(
            ">HH", decrypted[header_size:pairs_end]
        )
    ]

    # Whatever follows the weight pairs is the validator key.
    key = list(decrypted[pairs_end:])

    return weights, key