#include "straighten_hash_table.h"

/*
 * Insert a packed tableau into the hash table, copying its key bytes into the
 * slot's own entries buffer (copy mode).
 *
 * pt_hash_tb  table to insert into; only element [0] is used.
 * insert      tableau whose first key_length bytes form the key. When
 *             original_key_length is odd, the low 4 bits of the last key byte
 *             are cleared while hashing/copying and restored before returning,
 *             so the caller's tableau is unchanged on return.
 *             (Presumably that nibble is the unused half of a nibble-packed
 *             key -- confirm against the packing code.)
 * data        value stored in data_arr at the chosen slot.
 *
 * Returns the number of occupied slots stepped over before a free one was
 * found.
 *
 * A slot counts as free when the first 8 bytes of its entries buffer are all
 * zero, so 8 bytes are read from every probed slot regardless of key_length.
 * Probing is linear and wraps around; if the table has no free slot this
 * function does not terminate.
 */
int pt_insert(struct pt_hash_table* pt_hash_tb, struct packed_tableau* insert, const int data) {
	int collisions = 0;
	uint8_t t = 0;
	const int odd_key = (pt_hash_tb[0].original_key_length & 1) != 0;

	if (odd_key) {
		// Mask the trailing nibble so it influences neither the hash nor the stored key.
		t = insert[0].entries[pt_hash_tb[0].key_length-1] & 0xF;
		insert[0].entries[pt_hash_tb[0].key_length-1] = insert[0].entries[pt_hash_tb[0].key_length-1] & ~0xF;
	}

	uint32_t index = Hash_YoshimitsuTRIAD(insert[0].entries, pt_hash_tb[0].key_length);
	index = index % pt_hash_tb[0].size;

	// Read the 8-byte occupancy word with memcpy: the entries buffers are byte
	// arrays that are not guaranteed to be 8-byte aligned, so dereferencing
	// them through a uint64_t pointer would be undefined behaviour.
	uint64_t val;
	memcpy(&val, pt_hash_tb[0].pt_arr[index].entries, sizeof val);
	while (val != 0) {
		collisions++;
		index++;
		index = index % pt_hash_tb[0].size;
		memcpy(&val, pt_hash_tb[0].pt_arr[index].entries, sizeof val);
	}

	memcpy(pt_hash_tb[0].pt_arr[index].entries, insert[0].entries, pt_hash_tb[0].key_length * sizeof(uint8_t));
	pt_hash_tb[0].data_arr[index] = data;

	if (odd_key) {
		// The low nibble was cleared above, so OR-ing restores the original byte.
		insert[0].entries[pt_hash_tb[0].key_length-1] |= t;
	}
	return collisions;
}

/*
 * Insert a packed tableau into the hash table without duplicating its key
 * bytes: the packed_tableau struct itself is copied into the slot, so the slot
 * ends up holding the same entries pointer as *insert.
 *
 * A slot is considered free while its entries pointer is NULL. Probing is
 * linear with wrap-around. When original_key_length is odd, the low 4 bits of
 * the last key byte are cleared for hashing and put back before returning.
 *
 * Returns the number of occupied slots skipped before the free slot was found.
 */
int pt_insert_nocopy(struct pt_hash_table* pt_hash_tb, struct packed_tableau* insert, const int data) {
	const int odd_key = (pt_hash_tb->original_key_length & 1) != 0;
	uint8_t saved_nibble = 0;

	if (odd_key) {
		uint8_t *last_byte = &insert->entries[pt_hash_tb->key_length - 1];
		saved_nibble = *last_byte & 0xF;
		*last_byte = *last_byte & ~0xF;
	}

	uint32_t slot = Hash_YoshimitsuTRIAD(insert->entries, pt_hash_tb->key_length);
	slot = slot % pt_hash_tb->size;

	// Walk forward until an unused slot (NULL entries pointer) turns up.
	int probes = 0;
	for (; pt_hash_tb->pt_arr[slot].entries != NULL; slot = (slot + 1) % pt_hash_tb->size) {
		probes++;
	}

	memcpy(pt_hash_tb->pt_arr + slot, insert, sizeof(struct packed_tableau));
	pt_hash_tb->data_arr[slot] = data;

	if (odd_key) {
		uint8_t *last_byte = &insert->entries[pt_hash_tb->key_length - 1];
		*last_byte = *last_byte + saved_nibble;
	}
	return probes;
}

/*
 * `extern inline` declarations for the hashing, comparison and search helpers.
 * These functions are presumably defined `inline` in straighten_hash_table.h
 * (the header is not visible from here -- confirm). Under C99 inline
 * semantics, declaring them `extern inline` in exactly one translation unit
 * makes this file provide their external (out-of-line) definitions, which are
 * used wherever a call is not inlined.
 */
extern inline uint32_t Hash_YoshimitsuTRIAD(const uint8_t *str, size_t wrdlen);

extern inline int pt_compar_anysize_odd(const uint8_t* lhs, const uint8_t* rhs, size_t size);

extern inline int pt_compar_anysize_even(const uint8_t* lhs, const uint8_t* rhs, size_t size);

extern inline int pt_compar_numboxes_gt32kll32(const uint8_t* lhs, const uint8_t* rhs, size_t size);

extern inline int pt_compar_numboxes_36klgt32(const uint8_t* lhs, const uint8_t* rhs, size_t size);

extern inline int pt_search(const struct pt_hash_table* pt_hash_tb, const struct packed_tableau* key);

/*
 * Round num up to the nearest power of two. A value that already is a power
 * of two is returned unchanged. Because the arithmetic wraps in 32 bits, an
 * input of 0 and any input above 2^31 both yield 0.
 */
uint32_t next_pow2(uint32_t num) {
	uint32_t smeared = num - 1;
	// Copy the highest set bit into every lower position.
	for (unsigned int shift = 1; shift <= 16; shift <<= 1) {
		smeared |= smeared >> shift;
	}
	return smeared + 1;
}

/*
 * Base-2 logarithm of a power of two. Mask k selects every bit position whose
 * index has bit k set, so testing num against mask k yields bit k of the
 * result. For an input that is not a power of two the result is the bitwise
 * OR of the indices of all set bits, not a logarithm.
 */
uint32_t log2_pow2(const uint32_t num) {
	static const uint32_t index_bit_masks[5] = {
		0xAAAAAAAA, 0xCCCCCCCC, 0xF0F0F0F0, 0xFF00FF00, 0xFFFF0000
	};
	uint32_t result = 0;
	for (unsigned int bit = 0; bit < 5; bit++) {
		if ((num & index_bit_masks[bit]) != 0) {
			result |= (uint32_t)1 << bit;
		}
	}
	return result;
}

struct pt_hash_table* pt_build_hashtable(struct packed_tableau* all_dictionary_tableau, const size_t num_dict_tableau, const uint32_t key_length, struct shape_data_c * s_data, uint32_t copy_data) {
	uint32_t size = next_pow2(num_dict_tableau);
	/*if (size <= 268435456) {
	size = size<<2;
	}
	else if (size <= 1073741824) {
	size = size<<1;
	}*/
	if (size <= 33554432) {
	size = size<<2;
	}
	else if (size <= 1073741824) {
	size = size<<1;
	}
	straighten_log(STRAIGHTEN_VVINFO, "Generating a dictionary tableau hash of size %d for shape %s.", size, s_data[0].shape_string);
	struct pt_hash_table * pt_hash_tb = (struct pt_hash_table*) calloc(1, sizeof(struct pt_hash_table));
	pt_hash_tb[0].size = size;
	pt_hash_tb[0].and_size = size-1;
	pt_hash_tb[0].original_key_length = key_length;
	pt_hash_tb[0].copy_mode = copy_data;
	pt_hash_tb[0].pt_arr = (struct packed_tableau*) calloc(size, sizeof(struct packed_tableau));
	if(copy_data) {
		uint8_t * tableau_entries = (uint8_t*) calloc(size * s_data[0].num_packed_boxes, sizeof(uint8_t));
		for(int tab = 0; tab < size; tab++) {
			pt_hash_tb[0].pt_arr[tab].entries = tableau_entries;
			tableau_entries = tableau_entries + s_data[0].num_packed_boxes;
		}
	}
	pt_hash_tb[0].data_arr = (uint32_t*) calloc(size, sizeof(uint32_t));

	//branchless round up to nearest even and then divide by two
	int kl = (key_length+1)&~1;
	kl = kl>>1;
	//printf("%d \n", kl);
	pt_hash_tb[0].key_length = kl;
    
	int collisions = 0;
	int c;
	int max_c=0;
	for(int monomial=0; monomial < num_dict_tableau; monomial++) {
		if(copy_data) {
			c = pt_insert(pt_hash_tb, all_dictionary_tableau+monomial, monomial);
		}
		else {
			c = pt_insert_nocopy(pt_hash_tb, all_dictionary_tableau+monomial, monomial);
		}
		collisions += c;
	    if (c > max_c) {max_c = c;}
	}

	if(pt_hash_tb[0].original_key_length & 1) {
		pt_hash_tb[0].compar = pt_compar_anysize_odd;
	} 
	else {
		pt_hash_tb[0].compar = pt_compar_anysize_even;
	}

	straighten_log(STRAIGHTEN_INFO, "Total number of %d collisions in the dictionary hash table. Maximum number of %d collisions.", collisions, max_c);
	return pt_hash_tb;
}