use rand::{CryptoRng, Rng};
use crate::ecdh;
use crate::ecdh::p256;
use crate::ecdh::p256_derive;
use crate::ecdh::x25519;
/// KEM algorithm identifiers.
///
/// In this module only [`Algorithm::X25519`] and [`Algorithm::Secp256r1`] have
/// key generation, (de)serialization and encapsulation support; the remaining
/// variants are rejected with `Error::UnsupportedAlgorithm` by those functions.
// All variants are fieldless, so `Eq` and `Hash` are always valid; deriving them
// lets the identifier be used as a `HashMap`/`HashSet` key and in `Eq` bounds.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Algorithm {
    /// Maps to `ecdh::Algorithm::X25519`.
    X25519,
    /// Maps to `ecdh::Algorithm::X448`.
    X448,
    /// Maps to `ecdh::Algorithm::P256`.
    Secp256r1,
    /// Maps to `ecdh::Algorithm::P384`.
    Secp384r1,
    /// Maps to `ecdh::Algorithm::P521`.
    Secp521r1,
    /// Not an ECDH algorithm; no implementation in this module.
    Kyber512,
    /// Not an ECDH algorithm; no implementation in this module.
    Kyber768,
    /// Presumably a Kyber768 + X25519 hybrid; its ECDH part maps to
    /// `ecdh::Algorithm::X25519`.
    Kyber768X25519,
    /// Not an ECDH algorithm; no implementation in this module.
    Kyber1024,
}
/// Errors returned by the KEM functions in this module.
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    /// An error propagated from the `ecdh` module (see `From<ecdh::Error>`).
    EcDhError(ecdh::Error),
    /// Not constructed in this module; presumably reserved for key-generation
    /// failures. NOTE(review): confirm intended use.
    KeyGen,
    /// Not constructed in this module; presumably reserved for encapsulation
    /// failures. NOTE(review): confirm intended use.
    Encapsulate,
    /// Not constructed in this module; presumably reserved for decapsulation
    /// failures. NOTE(review): confirm intended use.
    Decapsulate,
    /// The requested [`Algorithm`] is not implemented by this operation.
    UnsupportedAlgorithm,
    /// Private-key bytes could not be decoded, or the private key's curve does
    /// not match the ciphertext's curve during decapsulation.
    InvalidPrivateKey,
    /// Public-key bytes could not be decoded.
    InvalidPublicKey,
    /// Ciphertext bytes could not be decoded.
    InvalidCiphertext,
}
impl TryFrom<Algorithm> for ecdh::Algorithm {
    type Error = &'static str;

    /// Maps a KEM [`Algorithm`] onto the ECDH algorithm that implements it.
    ///
    /// # Errors
    ///
    /// Returns a static message for the pure Kyber variants, which have no ECDH
    /// counterpart.
    fn try_from(value: Algorithm) -> Result<Self, Self::Error> {
        let curve = match value {
            // The hybrid variant uses X25519 as its ECDH component.
            Algorithm::X25519 | Algorithm::Kyber768X25519 => ecdh::Algorithm::X25519,
            Algorithm::X448 => ecdh::Algorithm::X448,
            Algorithm::Secp256r1 => ecdh::Algorithm::P256,
            Algorithm::Secp384r1 => ecdh::Algorithm::P384,
            Algorithm::Secp521r1 => ecdh::Algorithm::P521,
            Algorithm::Kyber512 | Algorithm::Kyber768 | Algorithm::Kyber1024 => {
                return Err("provided algorithm is not an ECDH algorithm");
            }
        };
        Ok(curve)
    }
}
impl From<ecdh::Error> for Error {
fn from(value: ecdh::Error) -> Self {
Error::EcDhError(value)
}
}
/// A private key for one of the curves implemented in this module.
///
/// NOTE(review): this holds secret material; no zeroize-on-drop is visible in
/// this module — confirm the inner key types clear themselves.
pub enum PrivateKey {
    /// An X25519 private key.
    X25519(x25519::PrivateKey),
    /// A P-256 private key.
    P256(p256::PrivateKey),
}
/// A public key for one of the curves implemented in this module.
pub enum PublicKey {
    /// An X25519 public key.
    X25519(x25519::PublicKey),
    /// A P-256 public key.
    P256(p256::PublicKey),
}
/// A KEM ciphertext.
///
/// `encapsulate` fills this with the ephemeral public key it generates, so each
/// variant wraps the curve's public-key type.
pub enum Ct {
    /// Ephemeral X25519 public key.
    X25519(x25519::PublicKey),
    /// Ephemeral P-256 public key.
    P256(p256::PublicKey),
}
/// A KEM shared secret.
///
/// Each variant holds the value returned by the curve's ECDH derive function,
/// which uses the curve's public-key type. NOTE(review): this is secret
/// material; confirm zeroization is handled by the inner types.
pub enum Ss {
    /// X25519 ECDH output.
    X25519(x25519::PublicKey),
    /// P-256 ECDH output.
    P256(p256::PublicKey),
}
impl PrivateKey {
pub fn encode(&self) -> Vec<u8> {
match self {
PrivateKey::X25519(k) => k.0.to_vec(),
PrivateKey::P256(k) => k.0.to_vec(),
}
}
pub fn decode(alg: Algorithm, bytes: &[u8]) -> Result<Self, Error> {
match alg {
Algorithm::X25519 => bytes
.try_into()
.map_err(|_| Error::InvalidPrivateKey)
.map(|k| Self::X25519(k)),
Algorithm::Secp256r1 => bytes
.try_into()
.map_err(|_| Error::InvalidPrivateKey)
.map(|k| Self::P256(k)),
_ => Err(Error::UnsupportedAlgorithm),
}
}
}
impl PublicKey {
pub fn encode(&self) -> Vec<u8> {
match self {
PublicKey::X25519(k) => k.0.to_vec(),
PublicKey::P256(k) => k.0.to_vec(),
}
}
pub fn decode(alg: Algorithm, bytes: &[u8]) -> Result<Self, Error> {
match alg {
Algorithm::X25519 => bytes
.try_into()
.map_err(|_| Error::InvalidPublicKey)
.map(|k| Self::X25519(k)),
Algorithm::Secp256r1 => bytes
.try_into()
.map_err(|_| Error::InvalidPublicKey)
.map(|k| Self::P256(k)),
_ => Err(Error::UnsupportedAlgorithm),
}
}
}
impl Ss {
pub fn encode(&self) -> Vec<u8> {
match self {
Ss::X25519(k) => k.0.to_vec(),
Ss::P256(k) => k.0.to_vec(),
}
}
}
impl Ct {
pub fn encode(&self) -> Vec<u8> {
match self {
Ct::X25519(k) => k.0.to_vec(),
Ct::P256(k) => k.0.to_vec(),
}
}
pub fn decode(alg: Algorithm, bytes: &[u8]) -> Result<Self, Error> {
match alg {
Algorithm::X25519 => bytes
.try_into()
.map_err(|_| Error::InvalidCiphertext)
.map(|ct| Self::X25519(ct)),
Algorithm::Secp256r1 => bytes
.try_into()
.map_err(|_| Error::InvalidCiphertext)
.map(|ct| Self::P256(ct)),
_ => Err(Error::UnsupportedAlgorithm),
}
}
}
pub fn secret_to_public(alg: Algorithm, sk: impl AsRef<[u8]>) -> Result<Vec<u8>, Error> {
match alg {
Algorithm::X25519 | Algorithm::Secp256r1 => {
ecdh::secret_to_public(alg.try_into().unwrap(), sk.as_ref()).map_err(|e| e.into())
}
_ => Err(Error::UnsupportedAlgorithm),
}
}
pub fn key_gen(
alg: Algorithm,
rng: &mut (impl CryptoRng + Rng),
) -> Result<(PrivateKey, PublicKey), Error> {
match alg {
Algorithm::X25519 => ecdh::x25519_key_gen(rng)
.map_err(|e| e.into())
.map(|(private, public)| (PrivateKey::X25519(private), PublicKey::X25519(public))),
Algorithm::Secp256r1 => ecdh::p256_key_gen(rng)
.map_err(|e| e.into())
.map(|(private, public)| (PrivateKey::P256(private), PublicKey::P256(public))),
_ => Err(Error::UnsupportedAlgorithm),
}
}
/// Encapsulates a shared secret to the recipient public key `pk`.
///
/// An ephemeral key pair on `pk`'s curve is drawn from `rng`. The ephemeral
/// public key becomes the ciphertext, and the ECDH value between the ephemeral
/// secret and `pk` becomes the shared secret.
///
/// NOTE(review): the shared secret is the raw ECDH output, and no KDF is applied
/// here. Confirm callers hash it before use (cf. DHKEM in RFC 9180).
///
/// # Errors
///
/// Key-generation and derivation failures from the `ecdh` module are converted
/// into [`Error`].
pub fn encapsulate(pk: &PublicKey, rng: &mut (impl CryptoRng + Rng)) -> Result<(Ss, Ct), Error> {
    let output = match pk {
        PublicKey::X25519(recipient) => {
            let (eph_secret, eph_public) = ecdh::x25519_key_gen(rng)?;
            let shared = x25519::derive(recipient, &eph_secret)?;
            (Ss::X25519(shared), Ct::X25519(eph_public))
        }
        PublicKey::P256(recipient) => {
            let (eph_secret, eph_public) = ecdh::p256_key_gen(rng)?;
            let shared = p256_derive(recipient, &eph_secret)?;
            (Ss::P256(shared), Ct::P256(eph_public))
        }
    };
    Ok(output)
}
pub fn decapsulate(ct: &Ct, sk: &PrivateKey) -> Result<Ss, Error> {
match ct {
Ct::X25519(ct) => {
let sk = if let PrivateKey::X25519(k) = sk {
k
} else {
return Err(Error::InvalidPrivateKey);
};
x25519::derive(ct, sk)
.map_err(|e| e.into())
.map(|k| Ss::X25519(k))
}
Ct::P256(ct) => {
let sk = if let PrivateKey::P256(k) = sk {
k
} else {
return Err(Error::InvalidPrivateKey);
};
p256_derive(ct, sk)
.map_err(|e| e.into())
.map(|k| Ss::P256(k))
}
}
}