#include "Decoder.h"

namespace dparser {
	// Number of tags: the extent of every tag dimension in the charts below.
	int Decoder::T = 0;
	// Id of the dummy tag occupying the start slot (position 0) and the EOS slot.
	int Decoder::pos_id_dummy = 0;

	void Decoder::viterbi(const Instance *inst, const bool constrained)
	{
		const int length = inst->size();

		_chart.resize(length+1, T); // length+1: one position for the end of sentence (EOS)
		_chart = NULL;
		_chart[0][pos_id_dummy] = new ChartItem(0, pos_id_dummy);
		for (int i = 1; i <= length; ++i) {
			for (int t = 0; t < T; ++t) {
				if (i == length && t != pos_id_dummy) continue;
				if (constrained && i < length &&  !inst->constrained_tags[i][t]) continue;
				for (int tL1 = 0; tL1 < T; ++tL1) {
					const ChartItem * const trace = _chart[i-1][tL1];
					if (!trace) continue;
					const double prob = trace->_prob + inst->prob_unigram[i][t] + inst->prob_bigram[i][tL1][t];
					list<const fvec *> fvs;
					fvs.push_back(&inst->fvec_unigram[i][t]);
					fvs.push_back(&inst->fvec_bigram[i][tL1][t]);
					const ChartItem * const item = new ChartItem(i, t, prob, fvs, trace);
					add_item(_chart[i][t], item);
				}
			}
		}
	}

	/*
	 * Forward pass of forward-backward, in log space.
	 *
	 * _forward_chart[i][t] accumulates, via log_add_if_not_negative_infinite()
	 * (presumably a log-sum-exp that ignores -inf terms), alpha[i-1][t'] plus the
	 * bigram score (i, t', t) over all predecessors t'. The unigram score of
	 * (i, t) is then added. Cells that no path reaches stay
	 * DOUBLE_NEGATIVE_INFINITY. Row 0 is seeded with LOG_EXP_ZERO at
	 * pos_id_dummy; row inst->size() is the EOS slot and only admits
	 * pos_id_dummy.
	 *
	 * constrained: when true, non-EOS positions only admit tags t with
	 *              inst->constrained_tags[i][t] set.
	 */
	void Decoder::forward( const Instance * const inst, const bool constrained )
	{
		const int n = inst->size();
		_forward_chart.resize(n + 1, T);
		_forward_chart = DOUBLE_NEGATIVE_INFINITY;
		_forward_chart[0][pos_id_dummy] = LOG_EXP_ZERO;

		for (int pos = 1; pos <= n; ++pos) {
			const bool at_eos = (pos == n);
			for (int tag = 0; tag < T; ++tag) {
				const bool skip = at_eos
					? (tag != pos_id_dummy)
					: (constrained && !inst->constrained_tags[pos][tag]);
				if (skip) continue;

				double acc = DOUBLE_NEGATIVE_INFINITY;
				for (int prev = 0; prev < T; ++prev) {
					const double alpha_prev = _forward_chart[pos - 1][prev];
					log_add_if_not_negative_infinite(acc, alpha_prev, inst->prob_bigram[pos][prev][tag]);
				}

				// Unreachable: leave the cell at -inf.
				if (equal_to_negative_infinite(acc)) continue;
				_forward_chart[pos][tag] = acc + inst->prob_unigram[pos][tag];
			}
		}
	}

	/*
	 * Backward pass of forward-backward, in log space.
	 *
	 * _backward_chart[i][t] accumulates, via log_add_if_not_negative_infinite(),
	 * beta[i+1][t'] plus the bigram score (i+1, t, t') plus the unigram score
	 * (i+1, t') over all successors t'. The unigram score of (i, t) itself is
	 * excluded because forward() already includes it. Row inst->size() (EOS) is
	 * seeded with LOG_EXP_ZERO at pos_id_dummy; row 0 only admits pos_id_dummy.
	 *
	 * constrained: when true, word positions 0 < i < size only admit tags t with
	 *              inst->constrained_tags[i][t] set. The start slot (row 0) is
	 *              never filtered, matching forward() and viterbi(), which seed
	 *              [0][pos_id_dummy] unconditionally.
	 */
	void Decoder::backward( const Instance * const inst, const bool constrained )
	{
		const int length = inst->size();
		_backward_chart.resize(length+1, T);
		_backward_chart = DOUBLE_NEGATIVE_INFINITY;
		_backward_chart[length][pos_id_dummy] = LOG_EXP_ZERO;
		for (int i = length-1; i >= 0; --i) {
			for (int t = 0; t < T; ++t) {
				if (i == 0 && t != pos_id_dummy) continue;
				// Position 0 is the start slot, the mirror of the EOS slot in
				// forward(); it must not be subject to the tag constraints.
				// (i < length is always true here, so it cannot serve as the guard.)
				if (constrained && i > 0 && !inst->constrained_tags[i][t]) continue;

				double log_sum = DOUBLE_NEGATIVE_INFINITY;
				for (int tR1 = 0; tR1 < T; ++tR1) {
					const double a = _backward_chart[i+1][tR1];
					log_add_if_not_negative_infinite(log_sum, a, inst->prob_bigram[i+1][t][tR1] + inst->prob_unigram[i+1][tR1]);
				}
				_backward_chart[i][t] = log_sum;
			}
		}
	}

	/*
	 * Copy the decoding result held in _chart (as filled by viterbi()) into `inst`:
	 *  - predicted_pos_ids: one tag id per position, -1 where none was recovered;
	 *  - predicted_fv:      the feature vectors along the best path, summed;
	 *  - predicted_prob:    the score of the best item in the EOS cell.
	 *
	 * If no path reaches the EOS cell, _chart[length][pos_id_dummy] is NULL.
	 * This happens, for example, when constrained decoding rules out every tag
	 * at some position. In that case the cleared defaults (all ids -1, empty fv,
	 * prob 0) are left in place instead of dereferencing NULL.
	 */
	void Decoder::get_result( Instance *inst ) const
	{
		const int length = inst->size();
		inst->predicted_pos_ids.clear();
		inst->predicted_pos_ids.resize(length, -1);
		inst->predicted_fv.clear();
		inst->predicted_prob = 0;

		const ChartItem * const best_item = _chart[length][pos_id_dummy];
		if (!best_item) return; // no complete path: keep the defaults set above

		inst->predicted_prob = best_item->_prob;
		get_best_parse_recursively(inst, best_item);
	}

	/*
	 * Follow the back-pointer (_trace) chain starting at `item` until it ends.
	 * For every item on the chain:
	 *  - record its tag in inst->predicted_pos_ids[_i]. The EOS item, whose _i
	 *    equals inst->size(), has no slot and is skipped;
	 *  - add every feature vector it references into inst->predicted_fv.
	 * A NULL `item` is a no-op. Despite the name, this walks the chain with a
	 * loop (the original recursion was a pure tail call).
	 */
	void Decoder::get_best_parse_recursively( Instance *inst, const ChartItem * const item ) const
	{
		for (const ChartItem *node = item; node != NULL; node = node->_trace) {
			if (node->_i < inst->size())
				inst->predicted_pos_ids[node->_i] = node->_pos_id;

			// collect the features used by this item
			list<const fvec *>::const_iterator f = node->_fvs.begin();
			const list<const fvec *>::const_iterator f_end = node->_fvs.end();
			for (; f != f_end; ++f)
				parameters::sparse_add(inst->predicted_fv, *f);
		}
	}

	/*
	 * Overwrite the instance's scores with marginals so that a later viterbi()
	 * picks the path maximizing the sum of per-position marginals:
	 *  - every bigram score becomes 0;
	 *  - every unigram cell is first set to -inf; the start (0) and EOS (len)
	 *    dummy slots are then set to 0;
	 *  - each word position 1..len-1 and tag t gets marginal_prob(inst, i, t).
	 * NOTE(review): marginal_prob() presumably reads the forward/backward charts,
	 * so forward()/backward() should have run on `inst` first -- confirm.
	 * Statement order is deliberate; marginal_prob() may read inst's scores.
	 */
	void Decoder::use_marginal_as_arc_score( Instance * const inst )
	{
		const int len = inst->size();

		// Bigram scores no longer contribute to the path score.
		inst->prob_bigram = 0;
		// Disabled alternative: bigram marginals. If re-enabled it must run before
		// prob_unigram is overwritten below, since it needs the original unigram
		// scores.
		/*inst->prob_bigram = DOUBLE_NEGATIVE_INFINITY;
		for (int i = 1; i <= len; ++i) {
			for (int t = 0; t < T; ++t) {
				if (i == len && t != pos_id_dummy) continue;
				for (int tL1 = 0; tL1 < T; ++tL1) {
					inst->prob_bigram[i][tL1][t] = marginal_prob(inst, i, tL1, t);
				}
			}
		}*/

		// Neutral score for the start and EOS dummy slots; all else starts invalid.
		inst->prob_unigram = DOUBLE_NEGATIVE_INFINITY;
		inst->prob_unigram[0][pos_id_dummy] = 0;
		inst->prob_unigram[len][pos_id_dummy] = 0;

		// Word positions receive their unigram marginals.
		for (int pos = 1; pos < len; ++pos)
			for (int tag = 0; tag < T; ++tag)
				inst->prob_unigram[pos][tag] = marginal_prob(inst, pos, tag);
	}

	/*
	 * Sanity check for the unigram marginals: at every word position 1..len-1 the
	 * values marginal_prob(inst, i, t) over all T tags must sum to 1, as judged
	 * by coarse_equal_to(). Every violating position is reported on cerr with
	 * 15-digit precision. If any violation occurred, log_Z(inst) is printed and
	 * the process terminates via exit(-1).
	 */
	void Decoder::check_marginal_prob( const Instance * const inst )
	{
		bool error_occur = false;
		const int len = inst->size();
		for (int i = 1; i < len; ++i) {
			double prob = 0.;
			for (int t = 0; t < T; ++t) {
				prob += marginal_prob(inst, i, t);
			}
			if (!coarse_equal_to(prob, 1.0)) {
				error_occur = true;
				cerr.precision(15);
				cerr << "\\sum_{t}{prob(i,t)} (m=" << i << ") : " << prob << endl;
			}
		}
		if (error_occur) {
			cerr << "\nlog_Z: " << log_Z(inst) << endl;
			exit(-1);
		}
	}




} // namespace dparser

