/*
  $Source: /local/data/cvs/yellowbank/postgres/src/y_ntlm/y_ntlm.c,v $
  $Revision: 1.4 $
  $State: Exp $
  $Date: 2008/03/06 19:31:46 $
  $Author: yrp001 $
  $Locker:  $

  Copyright (c) 2006
  Ronald Peterson
  (Y) Yellowbank
  All rights reserved.  Applicable BSD license terms can be found in
  the associated LICENSE file.
*/

/* PostgreSQL includes */
#include "postgres.h"
#include "fmgr.h"
#include "utils/datetime.h"

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <stdbool.h>
#include <sys/types.h>
#include <fcntl.h>
#include <unistd.h>

#include <openssl/des.h>
#include <openssl/md4.h>
#include <openssl/md5.h>

#ifdef PG_MODULE_MAGIC
// Emit the module magic block only when the server headers define the
// PG_MODULE_MAGIC macro, so the file still builds against headers that lack it.
PG_MODULE_MAGIC;
#endif

// 8.2 -> 8.3 macro change: newer headers provide SET_VARSIZE; on older
// headers fall back to assigning the varlena length through VARATT_SIZEP.
#ifndef SET_VARSIZE
#define SET_VARSIZE(v,l) (VARATT_SIZEP(v) = (l))
#endif

// Bit access over an octet string with bits numbered MSB-first: bit IDX lives
// in octet IDX/8 at bit position 7 - (IDX % 8) (so bit 0 is the high bit of
// octet 0). SETBIT ORs the bit in; GETBIT yields 0 or 1.
// Both macros evaluate their arguments more than once -- pass only
// side-effect-free expressions.
#define SETBIT( STR, IDX ) ( (STR)[(IDX)/8] |= (0x01 << (7 - ((IDX)%8))) )
#define GETBIT( STR, IDX ) (( ((STR)[(IDX)/8]) >> (7 - ((IDX)%8)) ) & 0x01)

// forward declarations of file-local helpers
static text *bin2hex_t_palloc( unsigned char *, size_t );
static void octstr_toupper( char *, int );
static void map_key( DES_cblock *, const char * );

// SQL-callable entry points (version-1 calling convention)
Datum y_lm( PG_FUNCTION_ARGS );
Datum y_ntlm( PG_FUNCTION_ARGS );

/*
 * Uppercase the first `len` octets of `str` in place, one octet at a time,
 * via pg_toupper(). The buffer need not be NUL-terminated; a len of zero
 * or less leaves it untouched.
 */
static
void
octstr_toupper( char *str, int len ) {
   char *cursor = str;
   int remaining = len;

   while( remaining > 0 ) {
      *cursor = pg_toupper( (unsigned char) *cursor );
      cursor++;
      remaining--;
   }
}

/*
 * Hex-encode `binsize` octets from `binstr` into a freshly palloc'd text
 * value containing 2 * binsize uppercase hex digits (high nibble first).
 *
 * The digits are produced in uppercase directly. Earlier code uppercased a
 * hard-coded 32 characters, which was only correct for 16-octet input: it
 * wrote past the buffer for shorter input and left lowercase digits for
 * longer input.
 */
static
text
*bin2hex_t_palloc( unsigned char *binstr, size_t binsize ) {
   static const char hexdigits[] = "0123456789ABCDEF";
   size_t hexsize = binsize * 2;
   text *hexstr = palloc( hexsize + VARHDRSZ );
   char *out = VARDATA( hexstr );
   size_t i;

   SET_VARSIZE( hexstr, hexsize + VARHDRSZ );
   for( i = 0; i < binsize; i++ ) {
      // binstr is unsigned, so the shifts see the octet's true value
      out[i * 2]     = hexdigits[(binstr[i] >> 4) & 0x0F];
      out[i * 2 + 1] = hexdigits[binstr[i] & 0x0F];
   }

   return( hexstr );
}

// Expand a 7-octet (56-bit) string into an 8-octet DES key.
//
// A DES key is eight octets, each with odd parity. The 56 input bits
// (MSB-first) are copied seven at a time into the high seven bits of each
// output octet. The low bit of every octet is left for parity, and
// DES_set_odd_parity() then sets it so each octet has an odd number of
// 1 bits.
//
// str must point to at least 7 readable octets.
static
void
map_key( DES_cblock *key, const char *str ) {
   // Work on unsigned octets: ORing 0x80 into a plain (possibly signed) char
   // is an implementation-defined conversion.
   const unsigned char *src = (const unsigned char *) str;
   unsigned char buf[8] = { 0 };
   int i;

   for( i = 0; i < 56; i++ ) {
      int oi = i / 7;            // destination octet
      int bi = i % 7;            // bit position within that octet (MSB-first)
      int ni = oi * 8 + bi;      // absolute destination bit index
      if( GETBIT( src, i ) ) {
         SETBIT( buf, ni );
      }
   }
   memcpy( key, buf, sizeof(buf) );
   DES_set_odd_parity( key );
}

/*
 * Compute one eight-octet half of the LM hash: derive a DES key from the
 * seven octets at `half` and DES-ECB encrypt the magic constant "KGS!@#$%"
 * with it. The ciphertext is written to out[0..7].
 */
static
void
lm_des_half( const char *half, unsigned char *out ) {
   DES_cblock magic;
   DES_cblock key;
   DES_cblock desout;
   DES_key_schedule ks;

   memcpy( &magic, "KGS!@#$%", sizeof(magic) );
   map_key( &key, half );
   DES_set_key_unchecked( &key, &ks );
   DES_ecb_encrypt( &magic, &desout, &ks, DES_ENCRYPT );
   memcpy( out, &desout, sizeof(desout) );
}

/*
 * y_lm( text ) -> text
 *
 * Returns the LAN Manager hash of the argument as a hex string (32 hex
 * digits, via bin2hex_t_palloc), or NULL for NULL input. Only the first
 * 14 octets of the input are used. Shorter input is zero-padded to 14
 * octets, and the result is uppercased with octstr_toupper() before hashing.
 */
PG_FUNCTION_INFO_V1( y_lm );
Datum
y_lm( PG_FUNCTION_ARGS )
{
   text *pass_in;
   int pass_in_len;
   char pass_key[14];
   unsigned char bin[16];

   if( PG_ARGISNULL(0) )
   {
      PG_RETURN_NULL();
   }
   pass_in = PG_GETARG_TEXT_P(0);

   // LM keys on at most 14 octets; anything longer is silently truncated
   pass_in_len = VARSIZE( pass_in ) - VARHDRSZ;
   if( pass_in_len > (int) sizeof(pass_key) ) {
      pass_in_len = (int) sizeof(pass_key);
   }

   // zero-pad (memset is standard C; bzero is not)
   memset( pass_key, 0, sizeof(pass_key) );
   memcpy( pass_key, VARDATA( pass_in ), pass_in_len );

   // password needs to be converted to uppercase
   octstr_toupper( pass_key, (int) sizeof(pass_key) );

   // each 7-octet half keys one DES encryption of the magic constant
   lm_des_half( pass_key, bin );
   lm_des_half( pass_key + 7, bin + 8 );

   PG_RETURN_TEXT_P( bin2hex_t_palloc( bin, sizeof(bin) ) );
}


/*
 * y_ntlm( text ) -> text
 *
 * Returns the NT hash of the argument as a hex string (32 hex digits, via
 * bin2hex_t_palloc), or NULL for NULL input. The hash is MD4 over the
 * password widened to two octets per input octet (input octet first, then a
 * zero octet). Only the first 128 input octets are used.
 *
 * NOTE(review): octet-by-octet widening equals UTF-16LE only when each input
 * octet is its own Unicode code point (ASCII, or a LATIN1 database). For
 * multibyte encodings such as UTF-8, non-ASCII passwords will hash
 * differently from Windows -- confirm whether callers depend on that.
 */
PG_FUNCTION_INFO_V1( y_ntlm );
Datum
y_ntlm( PG_FUNCTION_ARGS )
{
   enum { MAX_PASS_OCTETS = 128 };
   text *password;
   int pass_len;
   // bounded buffers live on the stack; no palloc/pfree needed
   char unipass[MAX_PASS_OCTETS * 2];
   unsigned char md4str[MD4_DIGEST_LENGTH];
   MD4_CTX ctx;
   int i;

   if( PG_ARGISNULL(0) )
   {
      PG_RETURN_NULL();
   }
   password = PG_GETARG_TEXT_P(0);

   pass_len = VARSIZE( password ) - VARHDRSZ;
   if( pass_len > MAX_PASS_OCTETS ) { pass_len = MAX_PASS_OCTETS; }

   // widen to two octets per character: low octet from input, high octet 0
   memset( unipass, 0, (size_t) pass_len * 2 );
   for( i = 0; i < pass_len; i++ ) {
      unipass[i * 2] = VARDATA( password )[i];
   }

   MD4_Init( &ctx );
   MD4_Update( &ctx, unipass, (size_t) pass_len * 2 );
   MD4_Final( md4str, &ctx );

   PG_RETURN_TEXT_P( bin2hex_t_palloc( md4str, sizeof(md4str) ) );
}
