Bert源码解读(三)之预训练部分
时间:2020-03-02 19:44:57
收藏:0
阅读:277
一、Masked LM
get_masked_lm_output
函数用于计算「任务#1」的训练 loss。输入为 BertModel 的最后一层 sequence_output 输出([batch_size, seq_length, hidden_size]),先找出输出结果中masked掉的词,然后构建一层全连接网络,接着构建一层节点数为vocab_size的softmax输出,从而与真实label计算损失。
def get_masked_lm_output(bert_config, input_tensor, #BertModel的最后一层sequence_output输出model.get_sequence_output(), [batch_size, seq_length, hidden_size] output_weights,#输入是model.get_embedding_table(),[vocab_size,hidden_size] positions, #mask词的位置 label_ids, #label,真实值结果 label_weights): """Get loss and log probs for the masked LM.""" # 根据positions位置获取masked词在Transformer的输出结果,即要预测的那些位置的encoder input_tensor = gather_indexes(input_tensor, positions) with tf.variable_scope("cls/predictions"): # 在输出之前添加一个带激活函数的全连接神经网络,只在预训练阶段起作用 with tf.variable_scope("transform"): input_tensor = tf.layers.dense( input_tensor, units=bert_config.hidden_size, activation=modeling.get_activation(bert_config.hidden_act), kernel_initializer=modeling.create_initializer( bert_config.initializer_range)) input_tensor = modeling.layer_norm(input_tensor) # output_weights是和传入的word embedding一样的,这里再添加一个bias output_bias = tf.get_variable( "output_bias", shape=[bert_config.vocab_size], initializer=tf.zeros_initializer()) logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size,max_pred_pre_seq,vocab_size] logits = tf.nn.bias_add(logits, output_bias) #得出masked词的softmax结果,[batch_size,max_pred_pre_seq,vocab_size] log_probs = tf.nn.log_softmax(logits, axis=-1) # label_ids表示mask掉的Token的id,下面这部分就是根据真实值计算loss了。 label_ids = tf.reshape(label_ids, [-1]) label_weights = tf.reshape(label_weights, [-1]) one_hot_labels = tf.one_hot( label_ids, depth=bert_config.vocab_size, dtype=tf.float32) # 但是由于实际MASK的可能不到20,比如只MASK18,那么label_ids有2个0(padding),而label_weights=[1, 1, ...., 0, 0],说明后面两个label_id是padding的,计算loss要去掉,label_weights就是起一个标记作用 per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) numerator = tf.reduce_sum(label_weights * per_example_loss) denominator = tf.reduce_sum(label_weights) + 1e-5 loss = numerator / denominator return (loss, per_example_loss, log_probs)
二、 Next Sentence Prediction
get_next_sentence_output函数用于计算「任务#2」的训练 loss,这部分比较简单,只需要再额外加一层softmax输出即可。输入为 BertModel 的最后一层 pooled_output 输出([batch_size, hidden_size]),因为该任务属于二分类问题,所以只需要每个序列的第一个 token【CLS】即可。
def get_next_sentence_output(bert_config, input_tensor,#pooled_output 输出,shape=[batch_size, hidden_size] labels): """Get loss and log probs for the next sentence prediction.""" # 标签0表示 下一个句子关系成立;标签1表示 下一个句子关系不成立。这个分类器的参数在实际Fine-tuning阶段会丢弃掉 with tf.variable_scope("cls/seq_relationship"): #初始化权重参数,最终的分类结果是只有2个,所以shape=[2,hidden_size] output_weights = tf.get_variable( "output_weights", shape=[2, bert_config.hidden_size], initializer=modeling.create_initializer(bert_config.initializer_range)) output_bias = tf.get_variable( "output_bias", shape=[2], initializer=tf.zeros_initializer()) logits = tf.matmul(input_tensor, output_weights, transpose_b=True)#输入与权重相乘,shape=[batch_size,2] logits = tf.nn.bias_add(logits, output_bias) log_probs = tf.nn.log_softmax(logits, axis=-1)#softmax输出:shape=[batch_size,2] #下面这部分就是根据真实值计算损失loss了 labels = tf.reshape(labels, [-1]) one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) loss = tf.reduce_mean(per_example_loss) return (loss, per_example_loss, log_probs)
原文:https://www.cnblogs.com/gczr/p/12396992.html
评论(0)