Bert源码解读(三)之预训练部分

时间:2020-03-02 19:44:57   收藏:0   阅读:277

一、Masked LM

get_masked_lm_output函数用于计算「任务#1」的训练 loss。输入为 BertModel 的最后一层 sequence_output 输出([batch_size, seq_length, hidden_size]),先找出输出结果中masked掉的词,然后构建一层全连接网络,接着构建一层节点数为vocab_size的softmax输出,从而与真实label计算损失。

def get_masked_lm_output(bert_config, 
                        input_tensor, #BertModel的最后一层sequence_output输出model.get_sequence_output(), [batch_size, seq_length, hidden_size]
                        output_weights,#输入是model.get_embedding_table(),[vocab_size,hidden_size]
                           positions, #mask词的位置
                         label_ids, #label,真实值结果
                         label_weights):
  """Get loss and log probs for the masked LM."""
  # 根据positions位置获取masked词在Transformer的输出结果,即要预测的那些位置的encoder
  input_tensor = gather_indexes(input_tensor, positions)

  with tf.variable_scope("cls/predictions"):
    # 在输出之前添加一个带激活函数的全连接神经网络,只在预训练阶段起作用
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=bert_config.hidden_size,
          activation=modeling.get_activation(bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              bert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)

    # output_weights是和传入的word embedding一样的,这里再添加一个bias
    output_bias = tf.get_variable(
        "output_bias",
        shape=[bert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size,max_pred_pre_seq,vocab_size]
    logits = tf.nn.bias_add(logits, output_bias)
    #得出masked词的softmax结果,[batch_size,max_pred_pre_seq,vocab_size]
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    # label_ids表示mask掉的Token的id,下面这部分就是根据真实值计算loss了。
    label_ids = tf.reshape(label_ids, [-1])
    label_weights = tf.reshape(label_weights, [-1])

    one_hot_labels = tf.one_hot(
        label_ids, depth=bert_config.vocab_size, dtype=tf.float32)

    # 但是由于实际MASK的可能不到20,比如只MASK18,那么label_ids有2个0(padding),而label_weights=[1, 1, ...., 0, 0],说明后面两个label_id是padding的,计算loss要去掉,label_weights就是起一个标记作用
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    denominator = tf.reduce_sum(label_weights) + 1e-5
    loss = numerator / denominator

  return (loss, per_example_loss, log_probs)

 

二、 Next Sentence Prediction

get_next_sentence_output函数用于计算「任务#2」的训练 loss,这部分比较简单,只需要再额外加一层softmax输出即可。输入为 BertModel 的最后一层 pooled_output 输出([batch_size, hidden_size]),因为该任务属于二分类问题,所以只需要每个序列的第一个 token【CLS】即可。

def get_next_sentence_output(bert_config,
                            input_tensor,#pooled_output 输出,shape=[batch_size, hidden_size]
                            labels):
  """Get loss and log probs for the next sentence prediction."""

 # 标签0表示 下一个句子关系成立;标签1表示 下一个句子关系不成立。这个分类器的参数在实际Fine-tuning阶段会丢弃掉
  with tf.variable_scope("cls/seq_relationship"):
  #初始化权重参数,最终的分类结果是只有2个,所以shape=[2,hidden_size]
    output_weights = tf.get_variable(
        "output_weights",
        shape=[2, bert_config.hidden_size],
        initializer=modeling.create_initializer(bert_config.initializer_range))
    output_bias = tf.get_variable(
        "output_bias", shape=[2], initializer=tf.zeros_initializer())
    
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)#输入与权重相乘,shape=[batch_size,2]
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)#softmax输出:shape=[batch_size,2]
    
    #下面这部分就是根据真实值计算损失loss了
    labels = tf.reshape(labels, [-1])
    one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, per_example_loss, log_probs)

 

原文:https://www.cnblogs.com/gczr/p/12396992.html

评论(0
© 2014 bubuko.com 版权所有 - 联系我们:wmxa8@hotmail.com
打开技术之扣,分享程序人生!