#!/usr/bin/python
# encoding: utf-8


import numpy as np
import os


class sentence:
    """Column-wise storage for one sentence.

    The three attributes are separate lists that are filled in step, one
    entry per input row: ``idx`` (row index field), ``chars`` (character
    field) and ``bies`` (segmentation tag).
    """

    def __init__(self):
        # Each instance gets its own three fresh, empty lists.
        self.idx, self.chars, self.bies = [], [], []


class data:
    """One tab-separated tagged file, loaded as a list of ``sentence`` objects.

    Usage is ``OpenFile(path)`` -> ``ReadData()`` -> ``CloseFile()``.

    Attributes:
        sentences: ``sentence`` objects in file order.
        fileName: path passed to ``OpenFile`` ("" before that).
        file: the open file object ("" until ``OpenFile`` is called).
    """

    def __init__(self):
        self.sentences = []
        self.fileName = ""
        self.file = ""

    def OpenFile(self, fileName):
        """Remember ``fileName`` and open it for reading in text mode."""
        self.fileName = fileName
        self.file = open(self.fileName, 'r')

    def ReadData(self):
        """Parse the opened file into ``self.sentences``.

        Every non-blank line is split on tabs: column 0 is stored as the
        index, column 1 as the character, and the first character of
        column 3 as the tag (an 'M' tag is normalised to 'I').  A blank
        line ('\\n' or '\\r\\n') ends the current sentence.

        Raises:
            IndexError: if a non-blank line has fewer than four columns or
                an empty fourth column.
        """
        sen = sentence()
        for line in self.file:
            if line == '\n' or line == '\r\n':
                self.sentences.append(sen)
                sen = sentence()
                continue
            lines = line.split('\t')
            idx = lines[0]
            char = lines[1]
            bies = lines[3][0]
            if bies == 'M':
                bies = 'I'
            sen.idx.append(idx)
            sen.chars.append(char)
            sen.bies.append(bies)
        # A file that does not end with a blank line would otherwise lose
        # its final sentence; flush whatever has been collected.
        if sen.idx:
            self.sentences.append(sen)

    def CloseFile(self):
        """Close the file opened by ``OpenFile``; no-op if nothing was opened."""
        # ``self.file`` is the empty-string placeholder until OpenFile runs.
        if self.file:
            self.file.close()

# Tag -> row index in the per-sentence vote table.
BIES = {'B': 0, 'I': 1, 'E': 2, 'S': 3}
# Row index -> tag (inverse of BIES).
Int_BIES = {0: 'B', 1: 'I', 2: 'E', 3: 'S'}
# Allowed "previous+current" tag pairs; only the keys are consulted.
Legal_bies = {"B+I": 0, "B+E": 0, "I+I": 0, "I+E": 0, "E+B": 0, "E+S": 0, "S+B": 0, "S+S": 0}
# The same pairs written "current+previous".  The decoder below can no longer
# produce an illegal pair, so it does not consult this table; it is kept so
# the module still exposes the name.
reversed_Legal_bies = {"I+B": 0, "E+B": 0, "I+I": 0, "E+I": 0, "B+E": 0, "S+E": 0, "B+S": 0, "S+S": 0}


def _legal_predecessors():
    """Map every tag index to the tag indices that may precede it."""
    label_ids = sorted(Int_BIES)
    table = {}
    for cur in label_ids:
        table[cur] = [prev for prev in label_ids
                      if Int_BIES[prev] + '+' + Int_BIES[cur] in Legal_bies]
    return table


def _count_bies_votes(candidates, sentence_length):
    """Count, per character position, how many candidates chose each tag."""
    votes = np.zeros((len(BIES), sentence_length))
    for candidate in candidates:
        for pos in range(sentence_length):
            votes[BIES[candidate.bies[pos]]][pos] += 1
    return votes


def _decode_bies(votes, sentence_length, predecessors):
    """Return the legal tag sequence with the highest total vote count."""
    impossible = float('-inf')
    label_ids = sorted(Int_BIES)

    # A sentence can only start with B or S.  Forbidden states get -inf
    # (not 0) so they can never tie with a legal state that got no votes.
    scores = [impossible] * len(label_ids)
    for tag in ('B', 'S'):
        scores[BIES[tag]] = votes[BIES[tag]][0]

    back_pointers = []  # back_pointers[p - 1][j]: best predecessor of tag j at position p
    for pos in range(1, sentence_length):
        new_scores = []
        pointers = []
        for cur in label_ids:
            # max() returns the first maximal item, so ties go to the
            # lowest tag index (B, I, E, S order).
            best_prev = max(predecessors[cur], key=lambda k: scores[k])
            pointers.append(best_prev)
            new_scores.append(scores[best_prev] + votes[cur][pos])
        back_pointers.append(pointers)
        scores = new_scores

    # A sentence can only end with E or S; a tie goes to E.
    label = max((BIES['E'], BIES['S']), key=lambda k: scores[k])
    path = [label]
    for pointers in reversed(back_pointers):
        label = pointers[label]
        path.append(label)
    path.reverse()
    return [Int_BIES[label] for label in path]


def viterbiMerge(outFile, files):
    """Merge several taggings of the same text into one legal BIES tagging.

    For every sentence the tag chosen by each input is counted per
    character, and the legal tag sequence (transitions listed in
    ``Legal_bies``, starting with B/S and ending with E/S) that collects
    the most votes is written out.

    Args:
        outFile: object with a ``write(str)`` method.  Each character is
            written as ``idx<TAB>char<TAB>_<TAB>tag`` followed by six
            ``<TAB>_`` columns; each sentence is followed by an empty line.
        files: sequence of objects exposing ``sentences``; every sentence
            exposes parallel ``idx``, ``chars`` and ``bies`` lists.  The
            sentence count and the ``idx``/``chars`` columns are taken from
            ``files[0]``.

    An empty ``files`` list writes nothing, and an empty sentence produces
    only its separating blank line.

    Raises:
        IndexError: if an input has fewer sentences, or a shorter sentence,
            than ``files[0]``.
        KeyError: if a tag is not one of B, I, E, S.
    """
    print("viterMerge...")
    if not files:
        return
    predecessors = _legal_predecessors()
    sentence_count = len(files[0].sentences)
    # Group the competing versions of each sentence before writing anything,
    # so a file with too few sentences fails before any output is produced.
    grouped = [[tagged.sentences[number] for tagged in files]
               for number in range(sentence_count)]

    for candidates in grouped:
        reference = candidates[0]
        sentence_length = len(reference.idx)
        if sentence_length:
            votes = _count_bies_votes(candidates, sentence_length)
            merged = _decode_bies(votes, sentence_length, predecessors)
            for pos in range(sentence_length):
                idx = reference.idx[pos]
                char = reference.chars[pos]
                outFile.write(idx + '\t' + char + '\t' + '_' + '\t' + merged[pos])
                outFile.write(('\t' + '_') * 6 + '\n')
        outFile.write('\n')
    return


if __name__ == "__main__":
    # Folder holding the tagged files to merge.
    # Alternative sources used previously:
    # filefold = sys.argv[1]
    filefold = "./weibo_devs/"
    # filefold = "./weibo_devs_submitted/"
    files = []
    # Sorted so that the file supplying the idx/char columns (files[0]) does
    # not depend on the arbitrary order returned by os.listdir.
    for filename in sorted(os.listdir(filefold)):
        fileName = os.path.join(filefold, filename)
        if not os.path.isfile(fileName):
            # Sub-directories cannot be opened as tagged files; skip them.
            continue
        tagged = data()
        tagged.OpenFile(fileName)
        try:
            tagged.ReadData()
        finally:
            # Release the handle even if the file is malformed.
            tagged.CloseFile()
        files.append(tagged)
    # Merge with Viterbi; the output file is closed even if merging fails.
    with open("dev.ultimate.conll", 'w') as outFile:
        viterbiMerge(outFile, files)
