1#!/usr/bin/env python3
2# SPDX-License-Identifier: BSD-2-Clause
3#
4# Copyright Amazon.com Inc. or its affiliates
5#
6import typing
7
8import boto3
9
10from cryptography.hazmat.primitives import hashes, serialization
11from cryptography.hazmat.primitives.asymmetric import (
12    AsymmetricSignatureContext,
13    utils as asym_utils,
14)
15from cryptography.hazmat.primitives.asymmetric.padding import (
16    AsymmetricPadding,
17    PKCS1v15,
18    PSS,
19)
20from cryptography.hazmat.primitives.asymmetric.rsa import (
21    RSAPrivateKey,
22    RSAPrivateNumbers,
23    RSAPublicKey,
24)
25
26
class _RSAPrivateKeyInKMS(RSAPrivateKey):
    """RSA private key whose signing operation is delegated to AWS KMS.

    The constructor fetches the public half of the key from KMS; signing
    requests are forwarded to the KMS ``Sign`` API.  Every operation that
    would need the private key material locally (``decrypt``,
    ``private_numbers``, ``private_bytes``) and the streaming ``signer``
    context raise ``NotImplementedError``.

    Attributes:
        arn: Identifier of the KMS key, passed as ``KeyId`` to KMS calls.
        client: boto3 KMS client used for all requests.
    """

    # Supported hash classes and the matching suffix of the KMS
    # ``SigningAlgorithm`` name.
    _HASH_SUFFIXES = (
        (hashes.SHA256, 'SHA_256'),
        (hashes.SHA384, 'SHA_384'),
        (hashes.SHA512, 'SHA_512'),
    )

    def __init__(self, arn):
        """Create the KMS client and download the public key for *arn*."""
        self.arn = arn
        self.client = boto3.client('kms')
        response = self.client.get_public_key(KeyId=self.arn)

        # Parse the DER-encoded public key.  It is kept under a private
        # name: an instance attribute called ``public_key`` would shadow
        # the ``public_key()`` method and make it uncallable.
        self._public_key = serialization.load_der_public_key(
                response['PublicKey'])

    @property
    def key_size(self):
        """Bit length of the key, taken from the public key."""
        return self._public_key.key_size

    def public_key(self) -> RSAPublicKey:
        """Return the public key downloaded from KMS at construction."""
        return self._public_key

    def sign(self, data: bytes, padding: AsymmetricPadding,
             algorithm: typing.Union[asym_utils.Prehashed,
                                     hashes.HashAlgorithm]
             ) -> bytes:
        """Sign *data* with the KMS-held key and return the signature.

        Args:
            data: The message to sign, or its digest when *algorithm* is
                a ``Prehashed`` wrapper.
            padding: ``PSS`` or ``PKCS1v15``.
            algorithm: SHA-256, SHA-384 or SHA-512, either directly or
                wrapped in ``Prehashed``.

        Returns:
            The ``Signature`` field of the KMS ``Sign`` response.

        Raises:
            TypeError: If the padding or hash algorithm is unsupported.
        """
        # Unwrap Prehashed only when it really is one: a plain
        # HashAlgorithm has no ``_algorithm`` attribute.
        if isinstance(algorithm, asym_utils.Prehashed):
            message_type = 'DIGEST'
            hash_algorithm = algorithm._algorithm
        else:
            # NOTE: per the KMS Sign API documentation, RAW messages are
            # limited to 4096 bytes; larger inputs must be pre-hashed by
            # the caller and passed with Prehashed.
            message_type = 'RAW'
            hash_algorithm = algorithm

        if isinstance(padding, PSS):
            signing_alg = 'RSASSA_PSS_'
        elif isinstance(padding, PKCS1v15):
            signing_alg = 'RSASSA_PKCS1_V1_5_'
        else:
            raise TypeError("Unsupported padding")

        for hash_cls, suffix in self._HASH_SUFFIXES:
            if isinstance(hash_algorithm, hash_cls):
                signing_alg += suffix
                break
        else:
            raise TypeError("Unsupported hashing algorithm")

        response = self.client.sign(
                KeyId=self.arn, Message=data,
                MessageType=message_type,
                SigningAlgorithm=signing_alg)

        return response['Signature']

    # The remaining RSAPrivateKey methods are not needed for KMS-backed
    # signing, so they raise NotImplementedError.
    def signer(
        self, padding: AsymmetricPadding, algorithm: hashes.HashAlgorithm
    ) -> AsymmetricSignatureContext:
        """Not supported: streaming signature contexts are unavailable."""
        raise NotImplementedError

    def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
        """Not supported by this wrapper."""
        raise NotImplementedError

    def private_numbers(self) -> RSAPrivateNumbers:
        """Not supported by this wrapper."""
        raise NotImplementedError

    def private_bytes(
        self,
        encoding: serialization.Encoding,
        format: serialization.PrivateFormat,
        encryption_algorithm: serialization.KeySerializationEncryption
    ) -> bytes:
        """Not supported by this wrapper."""
        raise NotImplementedError
99