Tensorflow中的buckets介绍以及model_with_buckets接口
bucket就是一种编码思想,bucket的存在是为了减小计算量,从而可以减少模型的训练时间。定义模型创建graph的时候,序列的长度是固定的,之后传入的所有序列都得是定义时指定的长度。这样所有的句子都要padding到指定的长度,很浪费存储空间,计算效率也不高。需要预先指定一系列的buckets。当然,使用dynamic_rnn或rnn这两个接口也可以减少运算时间。bucket是用在使用在cell(input,state)这种古老的方法上的。
bucket思路很简单,就是将输入长度分成不同的间隔,这样数据的在填充时只需要填充到相应的bucket长度即可,不需要都填充到最大长度。比如buckets取[(5,10), (10,20),(20,30)…],每个bucket的第一个数字表示source填充的长度,第二个数字表示target填充的长度。在chatbot中分布表示原始query长度和对话回复内容的长度。例如:‘我爱你’–>‘I love you’,应该会被分配到第一个bucket中,然后‘我爱你’会被pad成长度为5的序列,‘I love you’会被pad成长度为10的序列。在实际中对每个bucket都构造一个模型,然后训练时取相应长度的序列进行,而这些模型将会共享参数。其实这一部分可以参考现在的dynamic_rnn来进行理解,dynamic_rnn是对每个batch的数据将其pad至本batch中长度最大的样本,而bucket则是在数据预处理环节先对数据长度进行聚类操作。
在每次采用SGD更新模型参数时,会根据概率随机地从所有buckets中选择一个,并从中随机选取batch_size个训练样例,并对当前sub-graph中的参数进行优化,每个sub-graph之间权值共享。
Tensorflow中的model_with_buckets
[cc lang=”python”]
tf.contrib.legacy_seq2seq.model_with_buckets(
encoder_inputs,
decoder_inputs,
targets,
weights,
buckets,
seq2seq,
softmax_loss_function=None,
per_example_loss=False,
name=None
)
[/cc]
先看官方介绍的信息:

encoder_inputs: encoder的输入,一个tensor的列表。列表中每一项都是encoder时的一个词(batch)。
decoder_inputs: decoder的输入,同上
targets: 目标值,与decoder_input只相差一个
weights: 目标序列长度值的mask标志,如果是padding则weight=0,否则weight=1
buckets: 就是定义的bucket值,是一个列表:[(5,10), (10,20),(20,30)…]
seq2seq: 定义好的seq2seq模型,可以使用后面介绍的embedding_attention_seq2seq,embedding_rnn_seq2seq,basic_rnn_seq2seq等
softmax_loss_function: 计算误差的函数,(labels, logits),默认为sparse_softmax_cross_entropy_with_logits
per_example_loss: 如果为真,则调用sequence_loss_by_example,返回一个列表,其每个元素就是一个样本的loss值。如果为假,则调用sequence_loss函数,对一个batch的样本只返回一个求和的loss值。
name: Optional name for this operation, defaults to “model_with_buckets”.
返回值:
(outputs, losses)
outputs:是每个bucket的输出,其中每个元素是形状为[batch_size x output_size]或者[batch_size x num_decoder_symbols] 的张量,具体取决于seq2seq 模型。
losses:是一个张量list,代表每个bucket的损失。
关键代码介绍
[cc lang=”Python”]
#保存每个bucket对应的loss和output
losses = []
outputs = []
with ops.name_scope(name, “model_with_buckets”, all_inputs):
#对每个bucket都要选择数据进行构建模型
for j, bucket in enumerate(buckets):
#buckets之间的参数要进行复用
with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True if j > 0 else None):
#调用seq2seq进行解码得到输出,这里需要注意的是,encoder_inputs和decoder_inputs是定义好的placeholder,
#都是长度为序列最大长度的列表(也就是最大的那个buckets的长度),按上面的例子,这两个placeholder分别是长度为20和30的列表。
#在构建模型时,对于每个bucket,只取其对应的长度个placeholder即可,如对于(5,10)这个bucket,就取前5/10个placeholder进行构建模型
bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]], decoder_inputs[:bucket[1]])
outputs.append(bucket_outputs)
#如果指定per_example_loss则调用sequence_loss_by_example,losses添加的是一个batch_size大小的列表
if per_example_loss:
losses.append(
sequence_loss_by_example(
outputs[-1],
targets[:bucket[1]],
weights[:bucket[1]],
softmax_loss_function=softmax_loss_function))
#否则调用sequence_loss,对上面的结果进行求和,losses添加的是一个值
else:
losses.append(
sequence_loss(
outputs[-1],
targets[:bucket[1]],
weights[:bucket[1]],
softmax_loss_function=softmax_loss_function))
[/cc]
参考:https://zhuanlan.zhihu.com/p/32199930