#include "Decoder.h"

#include <vector>

namespace dparser {
	// Definitions of Decoder's static data members (shared by all Decoder objects).
	// Neither is assigned in this file; presumably set during model setup — confirm.
	//
	// T: number of joint tags. Every chart below has T columns, decoding loops run
	// t over [0, T), and a joint tag t is projected with joint_id_2_a[t] / joint_id_2_b[t].
	int Decoder::T;
	// pos_id_dummy: tag id used at the sentence boundaries. Charts are seeded with it
	// at position 0, and it is the only tag allowed at the final (EOS) position.
	int Decoder::pos_id_dummy;

	void Decoder::viterbi(const Instance *inst, const bool constrained)
	{
		const int length = inst->size();

		_chart.resize(length+1, T); // length+1: one position for the end of sentence (EOS)
		_chart = NULL;
		_chart[0][pos_id_dummy] = new ChartItem(0, pos_id_dummy);
		for (int i = 1; i <= length; ++i) {
			for (int t = 0; t < T; ++t) {
				if (i == length && t != pos_id_dummy) continue;
				if (constrained && i < length &&  !inst->constrained_tags[i][t]) continue;
				for (int tL1 = 0; tL1 < T; ++tL1) {
					const ChartItem * const trace = _chart[i-1][tL1];
					if (!trace) continue;
					double prob = trace->_prob;
					list<const fvec *> fvs;
					{
						if (!inst->prob_unigram_joint.empty()) {
							prob += inst->prob_unigram_joint[i][t];
							fvs.push_back(&inst->fvec_unigram_joint[i][t]);
						}
						if (!inst->prob_bigram_joint.empty()) {
							prob += inst->prob_bigram_joint[i][tL1][t];
							fvs.push_back(&inst->fvec_bigram_joint[i][tL1][t]);
						}
					}
					{
						const int tA = joint_id_2_a[t];
						const int tAL1 = joint_id_2_a[tL1];
						if (!inst->prob_unigram_a.empty()) {
							prob += inst->prob_unigram_a[i][tA];
							fvs.push_back(&inst->fvec_unigram_a[i][tA]);
						}
						if (!inst->prob_bigram_a.empty()) {
							prob += inst->prob_bigram_a[i][tAL1][tA];
							fvs.push_back(&inst->fvec_bigram_a[i][tAL1][tA]);
						}
					}
					{
						const int tB = joint_id_2_b[t];
						const int tBL1 = joint_id_2_b[tL1];
						if (!inst->prob_unigram_b.empty()) {
							prob += inst->prob_unigram_b[i][tB];
							fvs.push_back(&inst->fvec_unigram_b[i][tB]);
						}
						if (!inst->prob_bigram_b.empty()) {
							prob += inst->prob_bigram_b[i][tBL1][tB];
							fvs.push_back(&inst->fvec_bigram_b[i][tBL1][tB]);
						}
					}
					
					const ChartItem * const item = new ChartItem(i, t, prob, fvs, trace);
					add_item(_chart[i][t], item);
				}
			}
		}
	}

	/*
	 * Forward pass in log space.
	 *
	 * _forward_chart[i][t] becomes
	 *     log-sum over t' of ( _forward_chart[i-1][t'] + bigram(i, t', t) ) + unigram(i, t),
	 * with the recursion seeded by LOG_EXP_ZERO for pos_id_dummy at position 0.
	 * unigram(i, t) is the sum of the joint/A/B unigram scores and bigram(i, t', t)
	 * is the sum of the joint/A/B bigram scores; empty tables are skipped.
	 * Cells whose log-sum is "negative infinite" keep DOUBLE_NEGATIVE_INFINITY.
	 *
	 * Position n == inst->size() is the EOS slot and admits only pos_id_dummy.
	 * When `constrained` is true, tags marked false in inst->constrained_tags
	 * are skipped at positions 1..n-1. Every individual score read is asserted
	 * not to be negative infinity.
	 */
	void Decoder::forward( const Instance * const inst, const bool constrained )
	{
		const int n = inst->size();
		_forward_chart.resize(n + 1, T);
		_forward_chart = DOUBLE_NEGATIVE_INFINITY;
		_forward_chart[0][pos_id_dummy] = LOG_EXP_ZERO;

		for (int pos = 1; pos <= n; ++pos) {
			const bool at_eos = (pos == n);
			for (int cur = 0; cur < T; ++cur) {
				if (at_eos) {
					if (cur != pos_id_dummy) continue;
				} else if (constrained && !inst->constrained_tags[pos][cur]) {
					continue;
				}

				const int curA = joint_id_2_a[cur];
				const int curB = joint_id_2_b[cur];

				// Part of the score that depends only on (pos, cur).
				double unigram = 0;
				if (!inst->prob_unigram_joint.empty()) {
					const double s = inst->prob_unigram_joint[pos][cur];
					assert(!equal_to_negative_infinite(s));
					unigram += s;
				}
				if (!inst->prob_unigram_a.empty()) {
					const double s = inst->prob_unigram_a[pos][curA];
					assert(!equal_to_negative_infinite(s));
					unigram += s;
				}
				if (!inst->prob_unigram_b.empty()) {
					const double s = inst->prob_unigram_b[pos][curB];
					assert(!equal_to_negative_infinite(s));
					unigram += s;
				}

				// Log-sum over every predecessor tag of (incoming mass + transition).
				double acc = DOUBLE_NEGATIVE_INFINITY;
				for (int prev = 0; prev < T; ++prev) {
					const double incoming = _forward_chart[pos - 1][prev];
					const int prevA = joint_id_2_a[prev];
					const int prevB = joint_id_2_b[prev];

					double bigram = 0;
					if (!inst->prob_bigram_joint.empty()) {
						const double s = inst->prob_bigram_joint[pos][prev][cur];
						assert(!equal_to_negative_infinite(s));
						bigram += s;
					}
					if (!inst->prob_bigram_a.empty()) {
						const double s = inst->prob_bigram_a[pos][prevA][curA];
						assert(!equal_to_negative_infinite(s));
						bigram += s;
					}
					if (!inst->prob_bigram_b.empty()) {
						const double s = inst->prob_bigram_b[pos][prevB][curB];
						assert(!equal_to_negative_infinite(s));
						bigram += s;
					}

					log_add_if_not_negative_infinite(acc, incoming, bigram);
				}

				// Unreached cells keep the DOUBLE_NEGATIVE_INFINITY fill value.
				if (!equal_to_negative_infinite(acc)) {
					_forward_chart[pos][cur] = acc + unigram;
				}
			}
		}
	}

	/*
	 * Backward pass in log space, the mirror image of forward().
	 *
	 * _backward_chart[i][t] becomes the log-sum, over all ways of completing
	 * the sentence from tag t at position i, of the scores at positions
	 * i+1..length. This includes the unigram scores of the successor tag and
	 * the bigram scores linking i to i+1. The recursion is seeded with
	 * LOG_EXP_ZERO for pos_id_dummy at the EOS position `length`. Every other
	 * cell starts at DOUBLE_NEGATIVE_INFINITY.
	 *
	 * Position 0 (BOS) and position `length` (EOS) admit only pos_id_dummy.
	 * When `constrained` is true, tags marked false in inst->constrained_tags
	 * are skipped at the real positions 1..length-1. These are exactly the
	 * positions forward() constrains, so the BOS slot never depends on
	 * constrained_tags[0]. Every individual score read is asserted not to be
	 * negative infinity.
	 */
	void Decoder::backward( const Instance * const inst, const bool constrained )
	{
		const int length = inst->size();
		_backward_chart.resize(length+1, T);
		_backward_chart = DOUBLE_NEGATIVE_INFINITY;
		_backward_chart[length][pos_id_dummy] = LOG_EXP_ZERO;
		for (int i = length-1; i >= 0; --i) {
			// When the successor position is EOS, only its dummy tag is reachable:
			// every other _backward_chart[length][*] cell is still -inf.
			const bool next_is_eos = (i + 1 == length);
			for (int t = 0; t < T; ++t) {
				if (i == 0 && t != pos_id_dummy) continue;
				// Constraints apply to real positions only (i > 0), as in forward().
				if (constrained && i > 0 && !inst->constrained_tags[i][t]) continue;

				// Projections of the current tag; invariant across successor tags.
				const int tA = joint_id_2_a[t];
				const int tB = joint_id_2_b[t];

				double log_sum = DOUBLE_NEGATIVE_INFINITY;
				for (int tR1 = 0; tR1 < T; ++tR1) {
					// Skip successors that cannot occupy EOS, mirroring forward().
					// They would contribute nothing (their backward score is -inf).
					// Their EOS scores may legitimately be -inf:
					// use_marginal_as_arc_score() sets them so. The asserts below
					// would then fire spuriously.
					if (next_is_eos && tR1 != pos_id_dummy) continue;

					const double a = _backward_chart[i+1][tR1];
					double prob = 0;

					// Unigram scores of the successor tag at i+1.
					const int tAR1 = joint_id_2_a[tR1];
					const int tBR1 = joint_id_2_b[tR1];
					if (!inst->prob_unigram_joint.empty()) {
						const double p = inst->prob_unigram_joint[i+1][tR1];
						assert(!equal_to_negative_infinite(p));
						prob += p;
					}
					if (!inst->prob_unigram_a.empty()) {
						const double p = inst->prob_unigram_a[i+1][tAR1];
						assert(!equal_to_negative_infinite(p));
						prob += p;
					}
					if (!inst->prob_unigram_b.empty()) {
						const double p = inst->prob_unigram_b[i+1][tBR1];
						assert(!equal_to_negative_infinite(p));
						prob += p;
					}

					// Bigram scores of the transition t (at i) -> tR1 (at i+1).
					if (!inst->prob_bigram_joint.empty()) {
						const double p = inst->prob_bigram_joint[i+1][t][tR1];
						assert(!equal_to_negative_infinite(p));
						prob += p;
					}
					if (!inst->prob_bigram_a.empty()) {
						const double p = inst->prob_bigram_a[i+1][tA][tAR1];
						assert(!equal_to_negative_infinite(p));
						prob += p;
					}
					if (!inst->prob_bigram_b.empty()) {
						const double p = inst->prob_bigram_b[i+1][tB][tBR1];
						assert(!equal_to_negative_infinite(p));
						prob += p;
					}

					log_add_if_not_negative_infinite(log_sum, a, prob);
				}
				_backward_chart[i][t] = log_sum;
			}
		}
	}

	/*
	 * Copy the best path found by viterbi() into `inst`.
	 *
	 * Resets inst->predicted_tagids (length entries of -1), inst->predicted_fv
	 * and inst->predicted_prob. It then reads the item stored in the EOS cell
	 * _chart[length][pos_id_dummy]: its score becomes predicted_prob, and its
	 * back-pointer chain is unpacked by get_best_parse_recursively().
	 *
	 * If no path reached EOS, the cell is NULL. For example, viterbi() never
	 * fills a position at which the constraints exclude every tag. In that case
	 * a warning is printed and the freshly reset result is left as is.
	 *
	 * NOTE(review): predicted_tagids is reset here, but
	 * get_best_parse_recursively() writes inst->predicted_tags_joint. Confirm
	 * which member callers read, and that predicted_tags_joint is sized to
	 * at least `length`.
	 */
	void Decoder::get_result( Instance *inst ) const
	{
		const int length = inst->size();
		inst->predicted_tagids.clear();
		inst->predicted_tagids.resize(length, -1);
		inst->predicted_fv.clear();
		inst->predicted_prob = 0;

		const ChartItem * const best_item = _chart[length][pos_id_dummy];
		if (!best_item) {
			// Dereferencing a NULL cell would be undefined behaviour.
			cerr << "Decoder::get_result: no complete path reaches the end of the sentence; result left empty" << endl;
			return;
		}
		inst->predicted_prob = best_item->_prob;
		get_best_parse_recursively(inst, best_item);
	}

	/*
	 * Walk the back-pointer chain that starts at `item` (a NULL `item` is a
	 * no-op). The walk visits `item`, then item->_trace, and so on. For each
	 * visited chart item:
	 *   - if it sits on a real position (_i < inst->size()), its tag _pos_id is
	 *     stored in inst->predicted_tags_joint[_i]; the EOS item is skipped;
	 *   - each of its feature vectors is accumulated into inst->predicted_fv
	 *     via parameters::sparse_add.
	 * Despite the name, this is implemented as a loop, so long sentences do not
	 * deepen the call stack. The visiting order is the same as a recursive walk.
	 */
	void Decoder::get_best_parse_recursively( Instance *inst, const ChartItem * const item ) const
	{
		typedef list<const fvec *>::const_iterator fv_iter;

		for (const ChartItem *node = item; node != NULL; node = node->_trace) {
			if (node->_i < inst->size()) {
				inst->predicted_tags_joint[node->_i] = node->_pos_id;
			}

			const fv_iter last = node->_fvs.end();
			for (fv_iter f = node->_fvs.begin(); f != last; ++f) {
				parameters::sparse_add(inst->predicted_fv, *f);
			}
		}
	}

	/*
	 * Replace the instance's scores with per-position tag marginals, so that a
	 * subsequent viterbi() picks the tag sequence whose summed marginals are
	 * largest.
	 *
	 * Afterwards:
	 *   - every bigram table (joint, A, B) and the A/B unigram tables are 0;
	 *   - prob_unigram_joint[i][t] == marginal_prob(inst, i, t) for real
	 *     positions 1..len-1 (check_marginal_prob() expects these to sum to 1
	 *     over t);
	 *   - prob_unigram_joint[0][pos_id_dummy] and [len][pos_id_dummy] are 0;
	 *     all other entries at positions 0 and len are DOUBLE_NEGATIVE_INFINITY.
	 *
	 * All marginals are computed BEFORE any table is overwritten.
	 * marginal_prob() is defined elsewhere and may read inst->prob_* (this
	 * function used to note that marginal computation relied on the unigram
	 * scores). Overwriting first would feed it partially clobbered tables.
	 */
	void Decoder::use_marginal_as_arc_score( Instance * const inst )
	{
		const int len = inst->size();

		typedef std::vector<double>::size_type idx_t;
		const idx_t rows = len > 0 ? static_cast<idx_t>(len) : 0;
		const idx_t cols = static_cast<idx_t>(T);

		// Snapshot of marginal_prob(inst, i, t), row-major over (i, t), taken
		// while the original score tables are still intact.
		std::vector<double> marginals(rows * cols, 0.0);
		for (int i = 1; i < len; ++i) {
			for (int t = 0; t < T; ++t) {
				marginals[static_cast<idx_t>(i) * cols + t] = marginal_prob(inst, i, t);
			}
		}

		// Neutralise transition scores: only the unigram marginals drive decoding.
		// (Bigram marginals, marginal_prob(inst, i, tL1, t), are currently unused.)
		inst->prob_bigram_joint = 0;
		inst->prob_bigram_a = 0;
		inst->prob_bigram_b = 0;

		// Only the joint unigram table carries scores from here on.
		inst->prob_unigram_a = 0;
		inst->prob_unigram_b = 0;

		// Boundary positions (BOS = 0, EOS = len) score only the dummy tag.
		inst->prob_unigram_joint = DOUBLE_NEGATIVE_INFINITY;
		inst->prob_unigram_joint[0][pos_id_dummy] = 0;
		inst->prob_unigram_joint[len][pos_id_dummy] = 0;
		for (int i = 1; i < len; ++i) {
			for (int t = 0; t < T; ++t) {
				inst->prob_unigram_joint[i][t] = marginals[static_cast<idx_t>(i) * cols + t];
			}
		}
	}

	/*
	 * Sanity check: at every real position i (1..len-1), the marginals
	 * marginal_prob(inst, i, t) summed over all tags t must be 1 up to
	 * coarse_equal_to()'s tolerance.
	 *
	 * Each offending position is reported on stderr with 15 significant
	 * digits. If any position fails, log_Z(inst) is printed and the process
	 * terminates with exit(-1).
	 */
	void Decoder::check_marginal_prob( const Instance * const inst )
	{
		bool error_occur = false;
		const int len = inst->size();
		for (int i = 1; i < len; ++i) {
			double prob = 0.;
			for (int t = 0; t < T; ++t) {
				prob += marginal_prob(inst, i, t);
			}
			if (!coarse_equal_to(prob, 1.0)) {
				error_occur = true;
				// Set only on the error path, so stderr formatting is untouched otherwise.
				cerr.precision(15);
				cerr << "\\sum_{t}{prob(i,t)} (m=" << i << ") : " << prob << endl;
			}
		}
		if (error_occur) {
			cerr << "\nlog_Z: " << log_Z(inst) << endl;
			exit(-1);
		}
	}




} // namespace dparser

