#ifndef _PARSER_
#define _PARSER_

#pragma once
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>
using namespace std;

#include "IOPipe.h"
#include "FGen.h"
#include "Decoder.h"

#include "Parameters.h"
#include "common.h"
#include "GzFile.h"

/*******************
There seem to be conflicts between "CharUtils.h" and "spthread.h":
the order of their #include directives cannot be reversed.
The root cause is not known.
 *******************/
#include "CharUtils.h"
#include "StringMap.h"
using namespace egstra;

#include "spthread.h"
#include "threadpool.h"
/*******************/

/* LBFGS training
The code is based on crfsuite, sgd-2.1 and CRF++.
Zhenghua Li
2013.09.06
*/

namespace dparser {
	/*
	this class controls the parsing process.
	*/
	class Parser
	{
	public:
		IOPipe m_pipe_train;
		IOPipe m_pipe_train2;
		IOPipe m_pipe_test;
		IOPipe m_pipe_dev;
		FGen m_fgen;
		parameters m_param;
		Decoder *m_decoder;

/* options */
	private:
		int _display_interval;

		string _dictionary_path;
		string _parameter_path;
		int _inst_max_len_to_throw;

		bool _train;
		int _iter_num;
        int _best_iter_num_so_far;
        double _best_accuracy;

		vector<int> _inst_idx_to_read;
		
		bool _self_training;
		bool _use_train2;
		int _inst_num_from_train2_one_iter;
		int _inst_num_from_train1_one_iter;
		string _filename_train2;
		int _inst_max_num_train2;

		string _filename_train;
		string _filename_dev;

		int _inst_max_num_train;
		bool _dictionary_exist;
		bool _pamameter_exist;
		int _param_tmp_num;
        int _use_best_parameter_num;

		bool _test;
		string _filename_test;
		string _filename_output;
		int _param_num_for_eval;
		int _inst_max_num_eval;
		int _test_batch_size;

		bool _verify_decoding_algorithm;
		
		int _thread_num;

		/* thread control */
		threadpool _tp;
		static sp_thread_mutex_t _mutex;
		static sp_thread_cond_t _cond_waiting_create_feat;	// waiting for the decoding-thread
		static sp_thread_cond_t _cond_waiting_update;	// waiting for the creating-features-thread to finish the current instance.
		static sp_thread_cond_t _cond_done_update;		// complete all the instances

		static vector<bool> _train_features_created;
		static int _train_create_feat_inst_i;
		static int _train_update_inst_i;

		static double _sum_loss, _t0, _t, _lambda, _eta, _decay, _gain;
		static vector<double> _g;
		static bool _mbr_decoding;
		static bool _test_tag_filter;
		static double _test_tag_filter_lambda;

		ofstream _of_tag_filter_prob;
		int tot_word_num;
		int tot_tag_num;
		int tot_correct_tag_num;
		void initialize_filter_stat() {
			tot_word_num = 0;
			tot_tag_num = 0;
			tot_correct_tag_num = 0;
		}
		void output_filter_stat()
		{
			cerr << " min-risk tag filter results: " << endl;
			fprintf(stderr, "oracle POS tagging accuracy: %d / %d = %.2f\n", 
				tot_correct_tag_num, tot_word_num, 100.0*tot_correct_tag_num/tot_word_num);
			fprintf(stderr, "average tag num per word: %d / %d = %.2f \n", 
				tot_tag_num, tot_word_num, 1.0*tot_tag_num/tot_word_num);
		}
		void evaluate_output_tag_filter(Instance *inst) {
			const int len = inst->size();
			tot_word_num += (len - 1);
			for (int wi = 1; wi < len; ++wi) {
				tot_tag_num += inst->filtered_tags[wi].size();
				ostringstream os_head_list;
				ostringstream os_prob_list;
				//os_prob_list.precision(15);
				for (int ti = 0; ti < inst->filtered_tags[wi].size(); ++ti) {
					os_head_list << (ti == 0 ? "" : "_")
						<< inst->filtered_tags[wi][ti];
					os_prob_list << (ti == 0 ? "" : "_")
						<< inst->prob_filtered_tags[wi][ti];
					if (inst->filtered_tags[wi][ti] == inst->cpostags[wi]) {
						++tot_correct_tag_num;
					}
				}
				_of_tag_filter_prob << os_prob_list.str() << endl;
				inst->pdeprels[wi] = os_head_list.str();
			}
			_of_tag_filter_prob << endl;
		}

		/* variables used in evaluate */
		int inst_num_processed_total;
		int oov_num;
		int oov_pos_correct_num;
		int word_punc_num_pos_correct;
		int word_punc_num_total;

		string _train_method;

		typedef struct {
			int			batch_size;				
			floatval_t  c2;						// Coefficient for L2 regularization
			//int         max_iterations;			// The maximum number of iterations (epochs) for SGD optimization
			int         period;					// The duration of iterations to test the stopping criterion.
			floatval_t  delta;					/** The threshold for the stopping criterion; an optimization process stops when
												the improvement of the log likelihood over the last ${period} iterations is no
												greater than this threshold.*/
			floatval_t  calibration_eta;		// The initial value of learning rate (eta) used for calibration
			floatval_t  calibration_rate;		// The rate of increase/decrease of learning rate for calibration.
			int         calibration_samples;	// The number of instances used for calibration
			int         calibration_candidates;	// The number of candidates of learning rate.
			int         calibration_max_trials;	// The maximum number of trials of learning rates for calibration

			floatval_t  lambda;					
			floatval_t  t0;	
		} l2sgd_training_option_t;
		l2sgd_training_option_t _l2sgd_opt;
		
	public:
		Parser() : m_decoder(0), _tp(0) {
			process_options();
			_tp = create_threadpool(max(1, _thread_num));
		}

		~Parser(void) {
			delete_decoder(m_decoder);
			destroy_threadpool(_tp);
			_tp = 0;
		}

		void process_options();

		void run()
		{
			if (_train) {
				pre_train();
				if (_train_method == "l2sgd")
					train_l2sgd();
				//else if (_train_method == "pa")
				//	train_passive_aggressive();
				else {
					cerr << "unknown train method: " << _train_method << endl;
					exit(-1);
				}
				post_train();
			}
			if (_test) test(_param_num_for_eval);
		}

		static Decoder *new_decoder() {
			Decoder *decoder = new Decoder();
			assert(decoder);
			decoder->process_options();
			return decoder;
		}

		static void delete_decoder(Decoder *&decoder) {
			if (decoder) {
				delete decoder;
				decoder = 0;
			}
		}

	private:
		typedef struct thread_arg_t {
			thread_arg_t(Parser * const parser, const int inst_num, Instance * const inst=0, const int inst_idx = -1, bool is_test=false)
				: _parser(parser), _inst_num(inst_num), _inst(inst), _inst_idx(inst_idx), _is_test(is_test) {}
			Parser * const _parser;
			const int _inst_num;
			Instance * const _inst;
			const int _inst_idx;
			const bool _is_test;
		} ;

		static void parse_one_inst_thread( void *arg );
		static void train_update_one_inst_thread( void *arg );

		double update_weights_or_gradients_with_gold_tree(const Instance *const inst, double * const g, const double gain) {
			/* add observed positive features */
			sparsevec sp_fv;
			m_fgen.create_all_pos_features_according_to_tree(inst, sp_fv, inst->cpostags);
			const double score = m_param.dot(sp_fv); // before update!

			sparsevec::const_iterator V_i = sp_fv.begin();
			const sparsevec::const_iterator V_end = sp_fv.end();
			for(; V_i != V_end; ++V_i) {
				const int id = V_i->first;
				const double val = V_i->second;
				assert(id < m_fgen.feature_dimentionality() && id >= 0);
				g[id] += val * gain;
			}
			return score;
		}

		void train_l2sgd();
		floatval_t l2sgd_calibration();

		// objection: min -(1/N) * \sum{logP(y|x)} + (\lambda / 2) * ||w||^2
		// sum_loss: -\sum{logP(y|x)} + C * ||w||^2   \lambda = 2C/N
		void l2sgd( const int N, 
			const floatval_t t0, 
			const floatval_t lambda, 
			const int num_epochs, 
			const bool calibration, 
			const int period, 
			const floatval_t epsilon);

		typedef struct {
			Parser *par;
		} lbfgs_internal_t;

		void parse(Decoder *decoder, Instance *inst, bool is_test) {
			const bool constrained = false; //(!inst->filtered_tags.empty());
			if (constrained) m_fgen.create_constrained_tag_matrix(inst);

			m_fgen.create_all_feature_vectors(inst);
			compute_all_probs(inst);
			if (_mbr_decoding) {
				decoder->compute_marginals(inst, false);
				decoder->use_marginal_as_arc_score(inst);
				if (is_test && _test_tag_filter) {
					filter_tag(inst);
					m_fgen.assign_filtered_tag_str(inst);
				}
			}

			decoder->decodeInterface(inst, constrained);
			m_fgen.assign_predicted_tag_str(inst);
			if(_verify_decoding_algorithm) verify_decoding_algorithm(inst);

			inst->predicted_fv.clear();
			m_fgen.dealloc_fvec_prob(inst);
		}
		
		void filter_tag(Instance *inst);
		
		void test(const int iter);

		Instance *get_instance(const int inst_idx) {
			const int real_inst_idx = _inst_idx_to_read[inst_idx];
			if (real_inst_idx < m_pipe_train.getInstanceNum()) 
				return m_pipe_train.getInstance(real_inst_idx);
			else
				return m_pipe_train2.getInstance(real_inst_idx - m_pipe_train.getInstanceNum());
		}

		void delete_one_train_instance_after_update_gradient(Instance *&inst) {
			if (inst->id < m_pipe_train.getInstanceNum()) {
				if (m_pipe_train.use_instances_posi()) {
					delete inst;
					inst = 0;
				}
			}
			else {
				if (m_pipe_train2.use_instances_posi()) {
					delete inst;
					inst = 0;
				}
			}
		}

		int get_inst_num_one_iter() const { return _inst_idx_to_read.size(); }
		void prepare_train_instances();
		void pre_train();
		void post_train() {
			m_pipe_train.dealloc_instance();
			m_pipe_train.closeInputFile();
			if (_use_train2) {
				m_pipe_train2.dealloc_instance();
				m_pipe_train2.closeInputFile();
			}
			m_pipe_dev.dealloc_instance();
		}
		void delete_candidate_heads(IOPipe &pipe) {
			for (int i = 0; i < pipe.getInstanceNum(); ++i) {
				Instance *inst = pipe.getInstance(i);
				if (!inst->constrained_tags.empty()) inst->constrained_tags.dealloc();
			}
		}

		void evaluate(IOPipe &pipe, const bool is_test);
		void reset_evaluate_metrics();
		void output_evaluate_metrics();

		void create_dictionaries(IOPipe &pipe, const bool collect_word=true);

		void load_dictionaries() {
			m_fgen.load_dictionaries(_dictionary_path);
			Decoder::T = m_fgen.tag_number();
			Decoder::pos_id_dummy = m_fgen.pos_id_dummy();
		}

		void save_dictionaries() {
			m_fgen.save_dictionaries(_dictionary_path);
		}

		void save_parameters(const int iter) {
			m_param.save(_parameter_path, iter);
		}

		void load_parameters(const int iter) {
			m_param.load(_parameter_path, iter);
		}

		void delete_parameters(const int iter) {
			m_param.delete_file(_parameter_path, iter);
		}

		void dot_all(const fvec * const fs, double * const probs, const int sz) const;

		void compute_all_probs(Instance *inst) const;

		void verify_decoding_algorithm( Instance * const inst);

		void evaluate_one_instance(const Instance * const inst);

		static void update_gradient(floatval_t *g, const fvec &fv, const double marg, const int n);
		static void update_gradient_one_inst(Parser *par, Decoder *decoder, const Instance *inst, double *g, const double gain);

		void eval_oov_pos( const Instance *inst, int &oov_num, int &oov_pos_correct_num )
		{
			for (int i = 1; i < inst->size(); ++i)
				if ( m_fgen.get_word_id(inst->forms[i]) < 0 ) {
					++oov_num;
					if (inst->cpostags[i] == inst->predicted_postags[i])
						++oov_pos_correct_num;
				}
		}

		int error_num_pos( const Instance *inst ) const
		{
			int error_num = 0;
			for (int i = 1; i < inst->size(); ++i)
				if (inst->cpostags[i] != inst->predicted_postags[i])
					++error_num;
			return error_num;
		}
	};
}


#endif

