#include "CryptOpenSsl.h"

// ---------------------------------------------------------------------------------------------

// Returns the magnitude of b as a byte string produced by BN_bn2bin
// (big-endian; the sign of b is not encoded).
String ToString(BIGNUM *b) {
	const int byteCount = BN_num_bytes(b);
	String bytes(0, byteCount);          // pre-sized buffer that BN_bn2bin fills in place
	uint8 *dst = (uint8 *)~bytes;
	BN_bn2bin(b, dst);
	return bytes;
}

// Interprets str as a big-endian unsigned integer and returns a newly
// allocated BIGNUM (caller frees it with BN_free), or NULL on failure.
// NULL is passed as the output argument so BN_bin2bn allocates the BIGNUM
// itself; if that call fails, OpenSSL does not free a caller-supplied BIGNUM,
// so passing BN_new() here would leak it.
BIGNUM * ToBigNum(const String& str) {
	return BN_bin2bn((const uint8 *)~str, str.GetCount(), NULL);
}

// Copies the entire contents of a memory BIO into a String.
// Takes ownership of bp: the BIO is always freed before returning.
// A NULL bp (e.g. a failed BIO_new) yields an empty String.
String ToString(BIO *bp) {
	if(!bp)
		return String();

	BUF_MEM *buffer = NULL;              // stays NULL if the BIO exposes no buffer
	BIO_get_mem_ptr(bp, &buffer);

	String ret;
	if(buffer && buffer->data)
		ret = String(buffer->data, (int)buffer->length);

	BIO_free(bp);
	return ret;
}

// Wraps the bytes of str in a read-only memory BIO.
// BIO_new_mem_buf refers to the supplied buffer instead of copying it
// (per the OpenSSL documentation), so str must outlive the returned BIO.
// The caller owns the BIO and must BIO_free it.
BIO * ToBIO(const String &str) {
	int count = str.GetCount();
	uint8 *data = (uint8 *)~str;
	return BIO_new_mem_buf(data, count);
}

// ---------------------------------------------------------------------------------------------

void Rsa::GenerateKeyPair(int bits, int exponent) {
	if(rsa) RSA_free(rsa);
	
	rsa = RSA_generate_key(bits, exponent, NULL, NULL);	
}

// Returns the held key as an unencrypted PEM private key (no cipher and no
// passphrase are passed to PEM_write_bio_RSAPrivateKey).
// Returns an empty String if the memory BIO cannot be created or the write fails.
String Rsa::PrivateKeyToPem() {
	ASSERT(rsa);
	
	BIO *bp = BIO_new(BIO_s_mem());
	if(!bp)
		return String();
	if(!PEM_write_bio_RSAPrivateKey(bp, rsa, NULL, NULL, 0, NULL, NULL)) {
		BIO_free(bp);
		return String();
	}
	
	return ToString(bp);                 // ToString frees bp
}

// Returns the public half of the held key in the PEM form written by
// PEM_write_bio_RSAPublicKey (the form read back by PublicKeyFromPem).
// Returns an empty String if the memory BIO cannot be created or the write fails.
String Rsa::PublicKeyToPem() {
	ASSERT(rsa);
	
	BIO *bp = BIO_new(BIO_s_mem());
	if(!bp)
		return String();
	if(!PEM_write_bio_RSAPublicKey(bp, rsa)) {
		BIO_free(bp);
		return String();
	}
	
	return ToString(bp);                 // ToString frees bp
}

void Rsa::PrivateKeyFromPem(const String &pem) {
	if(rsa) {
		RSA_free(rsa);
		rsa = NULL;
	}	
	BIO *bp = ToBIO(pem);
	PEM_read_bio_RSAPrivateKey(bp, &rsa, NULL, NULL);
	BIO_free(bp);
}

// Replaces the held key with the PEM-encoded private key stored in the
// l bytes at d. Any previous key is released and rsa is cleared to NULL
// before parsing, so the PEM reader allocates a fresh RSA object.
void Rsa::PrivateKeyFromPem(uint8 *d, int l) {
	if(rsa != NULL)
		RSA_free(rsa);
	rsa = NULL;

	BIO *src = BIO_new_mem_buf(d, l);
	PEM_read_bio_RSAPrivateKey(src, &rsa, NULL, NULL);
	BIO_free(src);
}

void Rsa::PublicKeyFromPem(const String &pem) {
	if(rsa) {
		RSA_free(rsa);
		rsa = NULL;
	}
	BIO *bp = ToBIO(pem);
	PEM_read_bio_RSAPublicKey(bp, &rsa, NULL, NULL);
	BIO_free(bp);
}

// Replaces the held key with the PEM-encoded RSA public key stored in the
// l bytes at d. Any previous key is released and rsa is cleared to NULL
// before parsing, so the PEM reader allocates a fresh RSA object.
void Rsa::PublicKeyFromPem(uint8 *d, int l) {
	if(rsa != NULL)
		RSA_free(rsa);
	rsa = NULL;

	BIO *src = BIO_new_mem_buf(d, l);
	PEM_read_bio_RSAPublicKey(src, &rsa, NULL, NULL);
	BIO_free(src);
}

// Signs the SHA-1 digest of msg with the held private key via RSA_sign(NID_sha1).
// Returns the raw signature bytes, or an empty String if signing fails.
// (Previously a failure left `len` uninitialized before Trim(len).)
String Rsa::SignSHA(const String &msg) {
	ASSERT(rsa);
	
	uint8 h[SHA_DIGEST_LENGTH];
	SHA1((const uint8 *)~msg, msg.GetCount(), h);
	
	String ret(0, RSA_size(rsa));        // a signature is at most RSA_size bytes
	unsigned int len = 0;
	if(RSA_sign(NID_sha1, h, SHA_DIGEST_LENGTH, (uint8 *)~ret, &len, rsa) != 1)
		return String();
	ret.Trim((int)len);
	
	return ret;
}

// Decrypts msg with the held private key using the given OpenSSL padding mode.
// Returns the recovered plaintext, or an empty String if RSA_private_decrypt
// reports an error (it returns -1, which must not reach Trim).
// Note: an empty plaintext and a failure cannot be told apart from the result.
String Rsa::Decrypt(const String &msg, int padding) {
	ASSERT(rsa);

	String ret(0, RSA_size(rsa));        // plaintext never exceeds RSA_size bytes
	int len = RSA_private_decrypt(msg.GetCount(), (const uint8 *)~msg, (uint8 *)~ret, rsa, padding);
	if(len < 0)
		return String();
	ret.Trim(len);
	
	return ret;
}

bool Rsa::VerifySHA(const String &msg, const String &sig) {
	ASSERT(rsa);

	uint8 h[20];
	
	SHA1((uint8 *)~msg, msg.GetCount(), h); 

	return RSA_verify(NID_sha1, h, 20, (uint8 *)~sig, sig.GetCount(), rsa);
}

// Encrypts msg with the held public key using the given OpenSSL padding mode.
// With RSA_NO_PADDING, msg must be exactly MaxMsgCount() bytes. Otherwise it
// may be up to and including MaxMsgCount() bytes; PKCS#1 v1.5 allows k - 11
// bytes, so a strict '<' would wrongly reject the largest legal message.
// Returns the ciphertext, or an empty String if RSA_public_encrypt fails
// (it returns -1, which must not reach Trim).
String Rsa::Encrypt(const String &msg, int padding) {
	ASSERT(rsa);
	ASSERT(padding != RSA_NO_PADDING || (msg.GetCount() == MaxMsgCount(padding)));
	ASSERT(padding == RSA_NO_PADDING || (msg.GetCount() <= MaxMsgCount(padding)));

	String ret(0, RSA_size(rsa));        // ciphertext is at most RSA_size bytes
	int len = RSA_public_encrypt(msg.GetCount(), (const uint8 *)~msg, (uint8 *)~ret, rsa, padding);
	if(len < 0)
		return String();
	ret.Trim(len);

	return ret;
}

// Largest plaintext length in bytes (inclusive) that RSA_public_encrypt
// accepts for `padding` with the held key, where k = RSA_size(rsa):
//   PKCS#1 v1.5 (and SSLv23):        k - 11
//   OAEP (SHA-1, MGF1, empty label): k - 2*20 - 2 = k - 42
//   no padding:                      exactly k
// Returns -1 for a padding mode not listed here.
int Rsa::MaxMsgCount(int padding) {
	ASSERT(rsa);
	
	int k = RSA_size(rsa);
	switch(padding) {
		case RSA_PKCS1_PADDING:
#ifdef RSA_SSLV23_PADDING
		// Newer OpenSSL versions no longer define SSLv23 padding.
		case RSA_SSLV23_PADDING:
#endif
			return k - 11;
		case RSA_PKCS1_OAEP_PADDING:
			return k - 42;               // two SHA-1 digests plus two marker bytes
		case RSA_NO_PADDING:
			return k;
	}
	return -1;
}

// Streams the key as its unencrypted private-key PEM text.
// Storing: the PEM is produced first and then written.
// Loading: the PEM is read first and then parsed into this object.
void Rsa::Serialize(Stream &s) {
	const bool loading = s.IsLoading();
	String pem;
	if(!loading)
		pem = PrivateKeyToPem();
	s % pem;
	if(loading)
		PrivateKeyFromPem(pem);
}

// ---------------------------------------------------------------------------------------------
#ifdef _DEBUG
// Round-trips a generated key through PEM and cross-checks sign/verify and
// encrypt/decrypt between the original key and the reloaded copies.
TEST(Rsa) {
	Rsa rsa;
	rsa.GenerateKeyPair(512);
	
	String pri = rsa.PrivateKeyToPem();
	String pub = rsa.PublicKeyToPem();

	Rsa rsa2;
	rsa2.PrivateKeyFromPem(pri);

	String pri2 = rsa2.PrivateKeyToPem();
	// Must come from rsa2: taking it from rsa would make CHECK(pub == pub2) vacuous.
	String pub2 = rsa2.PublicKeyToPem();

	CHECK(pri == pri2);
	CHECK(pub == pub2);
	CHECK(rsa.VerifySHA("Kleiner Test", rsa2.SignSHA("Kleiner Test")));
	CHECK(rsa.Decrypt(rsa2.Encrypt("Kleiner Test")) == "Kleiner Test");
	CHECK(!rsa.VerifySHA("Kleiner Test", rsa2.SignSHA("@Kleiner Test")));

	// A key loaded from the public PEM alone must still verify and encrypt.
	Rsa rsa3;
	rsa3.PublicKeyFromPem(pub);
	CHECK(rsa3.VerifySHA("Kleiner Test", rsa.SignSHA("Kleiner Test")));
	CHECK(rsa.Decrypt(rsa3.Encrypt("Kleiner Test")) == "Kleiner Test");
}
#endif

// ---------------------------------------------------------------------------------------------

