/*
    $Source: /local/data/cvs/yellowbank/postgres/src/y_crypto/y_pgcrypto.c,v $
    $Revision: 1.1 $
    $State: Exp $
    $Date: 2007/11/20 00:56:11 $
    $Author: yrp001 $
    $Locker:  $

    Copyright 2007
    (Y) Yellowbank
    Ronald Peterson

    https://www.yellowbank.com/

    This file is part of y_pgcrypto.

    y_pgcrypto is free software; you can redistribute it and/or modify
    it under the terms of the GNU Affero GPL version 3.0.  These
    license terms can be found in the included file agpl-3.0.txt.
*/

//________________________________________________________________________
#include "y_clib.h"

#include "postgres.h"
#include "fmgr.h"
#include "utils/datetime.h"
// tuple building functions and macros
#include "access/heapam.h"
#include "funcapi.h"

// Compatibility shim: PostgreSQL 8.3 introduced SET_VARSIZE(); on older
// servers (8.2) the varlena length is written by assigning through
// VARATT_SIZEP() instead.
#if !defined(SET_VARSIZE)
#define SET_VARSIZE(v,l) (VARATT_SIZEP(v) = (l))
#endif


//________________________________________________________________________
// utility functions
/*
 * Convert a PostgreSQL text value into a freshly palloc'd, NUL-terminated
 * C string.
 *
 * cp: a string previously returned by this function that should be
 *     released, or NULL.  Its bytes are zeroed before it is pfree'd so the
 *     old contents (which may be key material) do not linger.
 * tp: the source text value.  VARSIZE/VARDATA are applied to it directly,
 *     so it must already be in detoasted (4-byte header) form.
 *
 * Returns the new string.  There is no failure return: palloc() raises an
 * ERROR on out-of-memory rather than returning NULL.
 */
static
char
*tp2cp_repalloc( char *cp, const text *tp ) {
   size_t len = VARSIZE(tp) - VARHDRSZ;

   if( cp ) {
      memset( cp, '\0', strlen( cp ) + 1 );
      pfree( cp );
   }

   cp = (char *)palloc( len + 1 );
   memcpy( cp, VARDATA(tp), len );
   cp[len] = '\0';   // text data is not NUL-terminated
   return cp;
}

/*
 * Convert a NUL-terminated C string into a freshly palloc'd text value.
 *
 * tp: a text value previously returned by this function that should be
 *     released, or NULL.  Its bytes are zeroed before it is pfree'd.
 * cp: the source string; must be NUL-terminated (the terminator is not
 *     copied into the text value).
 *
 * Returns the new text value.  There is no failure return: palloc() raises
 * an ERROR on out-of-memory rather than returning NULL.
 */
static
text
*cp2tp_repalloc( text *tp, const char *cp ) {
   size_t len = strlen( cp );

   if( tp ) {
      // VARSIZE already includes the header; adding VARHDRSZ here would
      // write past the end of the allocation.
      memset( tp, '\0', VARSIZE(tp) );
      pfree( tp );
   }

   tp = (text *)palloc( len + VARHDRSZ );
   SET_VARSIZE( tp, len + VARHDRSZ );
   memcpy( VARDATA(tp), cp, len );
   return tp;
}


//________________________________________________________________________
/* Emit the module magic block only when the server headers define it;
   the server inspects this block when it loads the shared library. */
#if defined(PG_MODULE_MAGIC)
PG_MODULE_MAGIC;
#endif


//________________________________________________________________________
Datum y_mhash( PG_FUNCTION_ARGS );
PG_FUNCTION_INFO_V1( y_mhash );
/*
 * y_mhash( data bytea, hash_type text ) -> bytea
 *
 * Hashes the bytes of `data` with the algorithm named by `hash_type` using
 * y_calc_mhash(), and returns the digest (hash_size bytes, as reported by
 * y_calc_mhash) as a bytea.
 *
 * Returns NULL if either argument is NULL.
 * Raises ERRCODE_FEATURE_NOT_SUPPORTED ("hash type not available") when
 * y_calc_mhash does not return Y_HASH_OK.
 */
Datum
y_mhash( PG_FUNCTION_ARGS )
{
   bytea *data;
   text *type;
   char *hash_type = NULL;
   int ret;
   uint hash_size;
   bytea *result;
   y_octstr hash;

   if( PG_ARGISNULL(0) ||
       PG_ARGISNULL(1) )
   {
      PG_RETURN_NULL();
   }
   data = PG_GETARG_BYTEA_P(0);
   type = PG_GETARG_TEXT_P(1);

   // y_calc_mhash takes the algorithm name as a C string
   hash_type = tp2cp_repalloc( hash_type, type );

   y_octstr_init( hash );

   ret = y_calc_mhash( hash,
                       hash_type,
                       VARDATA(data),
                       VARSIZE(data) - VARHDRSZ,
                       &hash_size );

   if( ret != Y_HASH_OK ) {
      // ereport(ERROR) does not return.  The octet string is released via
      // y_octstr_clear rather than by the memory context (presumably
      // library-allocated), so clear it first or it leaks.
      y_octstr_clear( hash );
      pfree( hash_type );
      ereport(ERROR,
              (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
               errmsg("hash type not available")));
   }

   result = (bytea *)palloc( VARHDRSZ + hash_size );
   SET_VARSIZE( result, VARHDRSZ + hash_size );
   memcpy( VARDATA( result ), Y_OS_STR( hash ), hash_size );

   y_octstr_clear( hash );
   pfree( hash_type );

   PG_RETURN_BYTEA_P( result );
}


//________________________________________________________________________
Datum y_pg_generate_rsa_keys( PG_FUNCTION_ARGS );
PG_FUNCTION_INFO_V1( y_pg_generate_rsa_keys );
Datum
y_pg_generate_rsa_keys( PG_FUNCTION_ARGS )
{
   int32 len;
   int32 bitlen;
   y_full_key akey;
   y_octstr modulus;
   y_octstr public_exponent;
   y_octstr private_exponent;

   char **vals;
   TupleDesc td;
   HeapTuple ht;
   AttInMetadata* aim;
   Datum result;

   if( PG_ARGISNULL(0) )
   {
      PG_RETURN_NULL();
   }
   bitlen = PG_GETARG_INT32(0);

   y_full_key_init( &akey );
   y_octstr_init( modulus );
   y_octstr_init( public_exponent );
   y_octstr_init( private_exponent );

   y_generate_rsa_keys( &akey, 256, bitlen );

   vals = (char**)palloc( sizeof(char*) * 3 );

   // convert key parts to strings
   len = mpz_sizeinbase( Y_FULL_KEY_MOD(&akey), 16 ) + 1;
   vals[0] = (char *)palloc( len );
   gmp_snprintf( vals[0], len, "%Zx", Y_FULL_KEY_MOD(&akey) );

   len = mpz_sizeinbase( Y_FULL_KEY_PUB(&akey), 16 ) + 1;
   vals[1] = (char *)palloc( len );
   gmp_snprintf( vals[1], len, "%Zx", Y_FULL_KEY_PUB(&akey) );

   len = mpz_sizeinbase( Y_FULL_KEY_PRIV(&akey), 16 ) + 1;
   vals[2] = (char *)palloc( len );
   gmp_snprintf( vals[2], len, "%Zx", Y_FULL_KEY_PRIV(&akey) );

   /* Build a tuple descriptor for our result type */
   if (get_call_result_type(fcinfo, NULL, &td) != TYPEFUNC_COMPOSITE) {
      ereport(ERROR,
              (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
               errmsg("function returning record called in context "
                      "that cannot accept type record")));
      PG_RETURN_NULL();
   }

   // Make a persistant copy.
   td = CreateTupleDescCopy( td );

   aim = TupleDescGetAttInMetadata( td );
   ht = BuildTupleFromCStrings( aim, vals );

   /* make the tuple into a datum */
   result = HeapTupleGetDatum( ht );

   y_full_key_clear( &akey );
   y_octstr_clear( modulus );
   y_octstr_clear( public_exponent );
   y_octstr_clear( private_exponent );

   pfree( vals[0] );
   pfree( vals[1] );
   pfree( vals[2] );
   pfree( vals );

   PG_RETURN_DATUM( result );
}


//________________________________________________________________________
Datum y_pg_rsaes_oaep_encrypt( PG_FUNCTION_ARGS );
PG_FUNCTION_INFO_V1( y_pg_rsaes_oaep_encrypt );
Datum
y_pg_rsaes_oaep_encrypt( PG_FUNCTION_ARGS )
{
   text *messageHex;
   HeapTupleHeader publicKeyIN;
   text *hashType;
   text *label;

   bool isNull;
   text *modulusHex;
   text *publicExponentHex;

   char *messageHexCP = NULL;
   char *modulusHexCP = NULL;
   char *publicExponentHexCP = NULL;
   char *hashTypeCP = NULL;
   char *labelCP = NULL;

   y_octstr cipherTextOS;
   y_octstr messageOS;
   y_part_key publicKey;

   char *tmpCP = NULL;
   char *resultCP = NULL;
   text *result = NULL;

   if( PG_ARGISNULL(0) ||
       PG_ARGISNULL(1) ||
       PG_ARGISNULL(2) ||
       PG_ARGISNULL(3) )
   {
      PG_RETURN_NULL();
   }
   messageHex = PG_GETARG_TEXT_P(0);
   publicKeyIN = PG_GETARG_HEAPTUPLEHEADER(1);
   hashType = PG_GETARG_TEXT_P(2);
   label = PG_GETARG_TEXT_P(3);

   modulusHex = (text *)GetAttributeByNum( publicKeyIN, 1, &isNull );
   if( isNull ) { PG_RETURN_NULL(); }
   publicExponentHex = (text *)GetAttributeByNum( publicKeyIN, 2, &isNull );
   if( isNull ) { PG_RETURN_NULL(); }

   y_octstr_init( cipherTextOS );
   y_octstr_init( messageOS );
   y_part_key_init( &publicKey );

   // convert PostgreSQL datatype into types our functions need
   messageHexCP = tp2cp_repalloc( messageHexCP, messageHex );
   modulusHexCP = tp2cp_repalloc( modulusHexCP, modulusHex );
   publicExponentHexCP = tp2cp_repalloc( publicExponentHexCP, publicExponentHex );
   hashTypeCP = tp2cp_repalloc( hashTypeCP, hashType );
   labelCP = tp2cp_repalloc( labelCP, label );

   y_octstr_set_from_cp_hex( messageOS, messageHexCP );

   // set up public key
   mpz_set_str( Y_PART_KEY_MOD(&publicKey), modulusHexCP, 16 );
   mpz_set_str( Y_PART_KEY_EXP(&publicKey), publicExponentHexCP, 16 );

   y_rsaes_oaep_encrypt( cipherTextOS,
                         messageOS,
                         &publicKey,
                         hashTypeCP,
                         labelCP );

   tmpCP = y_cp_hex_realloc_from_octstr( resultCP, cipherTextOS );
   if( tmpCP == NULL ) {
      ereport(ERROR,
              (errcode(ERRCODE_INTERNAL_ERROR),
               errmsg("error allocating result character string")));
      PG_RETURN_NULL();
   }
   resultCP = tmpCP;

   result = cp2tp_repalloc( result, resultCP );

   y_octstr_clear( cipherTextOS );
   y_octstr_clear( messageOS );
   y_part_key_clear( &publicKey );

   pfree( messageHexCP );
   pfree( modulusHexCP );
   pfree( publicExponentHexCP );
   pfree( hashTypeCP );
   pfree( labelCP );
// don't free this
//   pfree( resultCP );

   PG_RETURN_TEXT_P( result );
}


//________________________________________________________________________
Datum y_pg_rsaes_oaep_decrypt( PG_FUNCTION_ARGS );
PG_FUNCTION_INFO_V1( y_pg_rsaes_oaep_decrypt );
Datum
y_pg_rsaes_oaep_decrypt( PG_FUNCTION_ARGS )
{
   text *cipherTextHex;
   HeapTupleHeader privateKeyIN;
   text *hashType;
   text *label;

   bool isNull;
   text *modulusHex;
   text *privateExponentHex;

   char *cipherTextHexCP = NULL;
   char *modulusHexCP = NULL;
   char *privateExponentHexCP = NULL;
   char *hashTypeCP = NULL;
   char *labelCP = NULL;

   y_octstr cipherTextOS;
   y_octstr messageOS;
   y_part_key privateKey;

   char *tmpCP = NULL;
   char *resultCP = NULL;
   text *result = NULL;

   if( PG_ARGISNULL(0) ||
       PG_ARGISNULL(1) ||
       PG_ARGISNULL(2) ||
       PG_ARGISNULL(3) )
   {
      PG_RETURN_NULL();
   }
   cipherTextHex = PG_GETARG_TEXT_P(0);
   privateKeyIN = PG_GETARG_HEAPTUPLEHEADER(1);
   hashType = PG_GETARG_TEXT_P(2);
   label = PG_GETARG_TEXT_P(3);

   modulusHex = (text *)GetAttributeByNum( privateKeyIN, 1, &isNull );
   if( isNull ) { PG_RETURN_NULL(); }
   privateExponentHex = (text *)GetAttributeByNum( privateKeyIN, 2, &isNull );
   if( isNull ) { PG_RETURN_NULL(); }

   y_octstr_init( cipherTextOS );
   y_octstr_init( messageOS );
   y_part_key_init( &privateKey );

   // convert PostgreSQL datatype into types our functions need
   cipherTextHexCP = tp2cp_repalloc( cipherTextHexCP, cipherTextHex );
   modulusHexCP = tp2cp_repalloc( modulusHexCP, modulusHex );
   privateExponentHexCP = tp2cp_repalloc( privateExponentHexCP, privateExponentHex );
   hashTypeCP = tp2cp_repalloc( hashTypeCP, hashType );
   labelCP = tp2cp_repalloc( labelCP, label );
   y_octstr_set_from_cp_hex( cipherTextOS, cipherTextHexCP );

   // set up private key
   mpz_set_str( Y_PART_KEY_MOD(&privateKey), modulusHexCP, 16 );
   mpz_set_str( Y_PART_KEY_EXP(&privateKey), privateExponentHexCP, 16 );

   y_rsaes_oaep_decrypt( messageOS,
                         cipherTextOS,
                         &privateKey,
                         hashTypeCP,
                         labelCP );

   tmpCP = y_cp_hex_realloc_from_octstr( resultCP, messageOS );
   if( tmpCP == NULL ) {
      ereport(ERROR,
              (errcode(ERRCODE_INTERNAL_ERROR),
               errmsg("error allocating result character string")));
      PG_RETURN_NULL();
   }
   resultCP = tmpCP;

   result = cp2tp_repalloc( result, resultCP );

   y_octstr_clear( cipherTextOS );
   y_octstr_clear( messageOS );
   y_part_key_clear( &privateKey );
   pfree( cipherTextHexCP );
   pfree( modulusHexCP );
   pfree( privateExponentHexCP );
   pfree( hashTypeCP );
   pfree( labelCP );
//   pfree( resultCP );

   PG_RETURN_TEXT_P( result );
}


//________________________________________________________________________
Datum y_pg_rsassa_pss_sign( PG_FUNCTION_ARGS );
PG_FUNCTION_INFO_V1( y_pg_rsassa_pss_sign );
Datum
y_pg_rsassa_pss_sign( PG_FUNCTION_ARGS )
{
   text *messageHex;
   HeapTupleHeader privateKeyIN;
   text *hashType;
   int32 saltLen;

   bool isNull;
   text *modulusHex;
   text *privateExponentHex;

   char *messageHexCP = NULL;
   char *modulusHexCP = NULL;
   char *privateExponentHexCP = NULL;
   char *hashTypeCP = NULL;

   y_octstr messageOS;
   y_octstr signatureOS;
   y_part_key privateKey;

   char *tmpCP = NULL;
   char *resultCP = NULL;
   text *result = NULL;

   if( PG_ARGISNULL(0) ||
       PG_ARGISNULL(1) ||
       PG_ARGISNULL(2) ||
       PG_ARGISNULL(3) )
   {
      PG_RETURN_NULL();
   }
   messageHex = PG_GETARG_TEXT_P(0);
   privateKeyIN = PG_GETARG_HEAPTUPLEHEADER(1);
   hashType = PG_GETARG_TEXT_P(2);
   saltLen = PG_GETARG_INT32(3);

   modulusHex = (text *)GetAttributeByNum( privateKeyIN, 1, &isNull );
   if( isNull ) { PG_RETURN_NULL(); }
   privateExponentHex = (text *)GetAttributeByNum( privateKeyIN, 2, &isNull );
   if( isNull ) { PG_RETURN_NULL(); }

   y_octstr_init( messageOS );
   y_octstr_init( signatureOS );
   y_part_key_init( &privateKey );

   // convert PostgreSQL datatype into types our functions need
   messageHexCP = tp2cp_repalloc( messageHexCP, messageHex );
   modulusHexCP = tp2cp_repalloc( modulusHexCP, modulusHex );
   privateExponentHexCP = tp2cp_repalloc( privateExponentHexCP, privateExponentHex );
   hashTypeCP = tp2cp_repalloc( hashTypeCP, hashType );

   y_octstr_set_from_cp_hex( messageOS, messageHexCP );

   // set up private key
   mpz_set_str( Y_PART_KEY_MOD(&privateKey), modulusHexCP, 16 );
   mpz_set_str( Y_PART_KEY_EXP(&privateKey), privateExponentHexCP, 16 );

   y_rsassa_pss_sign( signatureOS,
                      messageOS,
                      &privateKey,
                      hashTypeCP,
                      saltLen );

   tmpCP = y_cp_hex_realloc_from_octstr( resultCP, signatureOS );
   if( tmpCP == NULL ) {
      ereport(ERROR,
              (errcode(ERRCODE_INTERNAL_ERROR),
               errmsg("error allocating result character string")));
      PG_RETURN_NULL();
   }
   resultCP = tmpCP;

   result = cp2tp_repalloc( result, resultCP );

   y_octstr_clear( signatureOS );
   y_octstr_clear( messageOS );
   y_part_key_clear( &privateKey );
   pfree( messageHexCP );
   pfree( modulusHexCP );
   pfree( privateExponentHexCP );
   pfree( hashTypeCP );
//   pfree( resultCP );

   PG_RETURN_TEXT_P( result );
}


//________________________________________________________________________
Datum y_pg_rsassa_pss_verify( PG_FUNCTION_ARGS );
PG_FUNCTION_INFO_V1( y_pg_rsassa_pss_verify );
Datum
y_pg_rsassa_pss_verify( PG_FUNCTION_ARGS )
{
   text *messageHex;
   text *signatureHex;
   HeapTupleHeader publicKeyIN;
   text *hashType;
   int32 saltLen;

   bool isNull;
   text *modulusHex;
   text *publicExponentHex;

   char *messageHexCP = NULL;
   char *signatureHexCP = NULL;
   char *modulusHexCP = NULL;
   char *publicExponentHexCP = NULL;
   char *hashTypeCP = NULL;

   y_octstr messageOS;
   y_octstr signatureOS;
   y_part_key publicKey;

   int result;

   if( PG_ARGISNULL(0) ||
       PG_ARGISNULL(1) ||
       PG_ARGISNULL(2) ||
       PG_ARGISNULL(3) ||
       PG_ARGISNULL(4) )
   {
      PG_RETURN_NULL();
   }

   messageHex = PG_GETARG_TEXT_P(0);
   signatureHex = PG_GETARG_TEXT_P(1);
   publicKeyIN = PG_GETARG_HEAPTUPLEHEADER(2);
   hashType = PG_GETARG_TEXT_P(3);
   saltLen = PG_GETARG_INT32(4);

   modulusHex = (text *)GetAttributeByNum( publicKeyIN, 1, &isNull );
   if( isNull ) { PG_RETURN_NULL(); }
   publicExponentHex = (text *)GetAttributeByNum( publicKeyIN, 2, &isNull );
   if( isNull ) { PG_RETURN_NULL(); }

   y_octstr_init( messageOS );
   y_octstr_init( signatureOS );
   y_part_key_init( &publicKey );

   // convert PostgreSQL datatype into types our functions need
   messageHexCP = tp2cp_repalloc( messageHexCP, messageHex );
   signatureHexCP = tp2cp_repalloc( signatureHexCP, signatureHex );
   modulusHexCP = tp2cp_repalloc( modulusHexCP, modulusHex );
   publicExponentHexCP = tp2cp_repalloc( publicExponentHexCP, publicExponentHex );
   hashTypeCP = tp2cp_repalloc( hashTypeCP, hashType );

   y_octstr_set_from_cp_hex( messageOS, messageHexCP );
   y_octstr_set_from_cp_hex( signatureOS, signatureHexCP );

   // set up private key
   mpz_set_str( Y_PART_KEY_MOD(&publicKey), modulusHexCP, 16 );
   mpz_set_str( Y_PART_KEY_EXP(&publicKey), publicExponentHexCP, 16 );

   result = y_rsassa_pss_verify( messageOS,
                                 signatureOS,
                                 &publicKey,
                                 hashTypeCP,
                                 saltLen );

   y_octstr_clear( messageOS );
   y_octstr_clear( signatureOS );
   y_part_key_clear( &publicKey );
   pfree( messageHexCP );
   pfree( signatureHexCP );
   pfree( modulusHexCP );
   pfree( publicExponentHexCP );
   pfree( hashTypeCP );

//   PG_RETURN_BOOL( (result == VALID_SIGNATURE) ? true : false );
   PG_RETURN_BOOL( result == VALID_SIGNATURE );
}
