BERT对中文文本分类实践(基于cnews数据集)


最近在学习BERT,所以想用文本分类试验一下,本以为会很简单,但还是遇到不少问题。
参考的文章有:
这篇文章有一个小坑
这篇是正解

数据集
链接:https://pan.baidu.com/s/1LzTidW_LrdYMokN—Nyag
提取码:zejw

数据格式如下
在这里插入图片描述
从https://github.com/google-research/bert上克隆项目。

下载BERT的中文预训练模型:
链接:https://pan.baidu.com/s/14JcQXIBSaWyY7bRWdJW7yg
提取码:mvtl

我的项目结构如下图:
在这里插入图片描述
cnews:cnews的数据
cnews_model: 训练后的模型
cnews_output: 预测的结果
fenlei.py: 由于预测的结果输出为概率,所以该文件用于将概率转化为标签,代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import pandas as pd
data_dir = "/home/zcsc/PycharmProjects/cnews/cnews_output/test_results.tsv"

lable = ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]

# 用pandas读取test_result.tsv,将标签设置为列名

data_df = pd.read_table(data_dir, sep="\t", names=lable, encoding="utf-8")

label_test = []
true_label = []
for i in range(data_df.shape[0]):
#获取一行中最大值对应的列名,追加到列表
  label_test.append(data_df.loc[i, :].idxmax())

with open("/home/zcsc/PycharmProjects/cnews/cnews/cnews.test.txt") as f:
  ture_texts = f.readlines()

在开始之前要知道BERT模型的输入输出格式,建议看下这篇文章:
https://www.pianshen.com/article/9173351154/

那么正式开始,该项目需要改动的源码部分只有run_classifier.py

一、修改run_classifier.py,添加自定义的Processor

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class MyProcessor(DataProcessor):
 
    def read_txt(self, data_dir, flag):
        with open(data_dir, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        random.seed(0)
        random.shuffle(lines)
        # 取少量数据做训练
        if flag == "train":
            lines = lines[0:5000]
        elif flag == "dev":
            lines = lines[0:500]
        elif flag == "test":
            lines = lines[0:100]
        return lines
 
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.train.txt"), "train"), "train")
 
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.val.txt"), "dev"), "dev")
 
    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.test.txt"), "test"), "test")
 
    def get_labels(self):
        """See base class."""
        return ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]
 
    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            split_line = line.strip().split("\t")
            text_a = tokenization.convert_to_unicode(split_line[1])
            text_b = None
            if set_type == "test":
                label = "体育"
            else:
                label = tokenization.convert_to_unicode(split_line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

这里有一个地方,就是在自定义的read_txt()函数里将train、test以及、val都进行了shuffle,这也是第一篇参考文章中出现的问题,如果不进行打乱,那么eval_loss会是0.1(1/种类),而且不仅是训练集,验证集以及测试集也需要打乱。

二、修改main方法

1
2
3
4
5
6
7
8
9
10
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "myprocess": MyProcessor,
  }

三、运行程序

在终端运行

1
python run_cnews_classifier.py --task_name=myprocess --do_train=true --do_eval=true --do_predict=false --data_dir=/home/zcsc/PycharmProjects/cnews/cnews --vocab_file=/home/zcsc/PycharmProjects/cnews/chinese_L-12_H-768_A-12/vocab.txt --bert_config_file=/home/zcsc/PycharmProjects/cnews/chinese_L-12_H-768_A-12/bert_config.json --init_checkpoint=/home/zcsc/PycharmProjects/cnews/chinese_L-12_H-768_A-12/bert_model.ckpt --train_batch_size=32 --max_seq_length=128 --output_dir=/home/zcsc/PycharmProjects/cnews/cnews_model

得到的结果为:
INFO:tensorflow: eval_accuracy = 0.93386775
INFO:tensorflow: eval_loss = 0.33081177
INFO:tensorflow: global_step = 468
INFO:tensorflow: loss = 0.3427003

可以得到fine-tune的模型,测试的命令如下:

1
python run_classifier.py --task_name=myprocess --do_train=false --do_eval=false --do_predict=true --data_dir=/home/zcsc/PycharmProjects/cnews/cnews --vocab_file=/home/zcsc/PycharmProjects/cnews/chinese_L-12_H-768_A-12/vocab.txt --bert_config_file=/home/zcsc/PycharmProjects/cnews/chinese_L-12_H-768_A-12/bert_config.json --init_checkpoint=/home/zcsc/PycharmProjects/cnews/cnews_model/model.ckpt-468  --max_seq_length=128 --output_dir=/home/zcsc/PycharmProjects/cnews/cnews_output

结果在output_dir中,运行fenlei.py可以得到标签。