Как уменьшить потребление памяти в цикле в TensorFlow?

У меня есть цикл в TensorFlow, который выглядит так:

with tf.device("/gpu:1"):
    losses = []

    for target, output in zip(targets, lstm_outputs):
        logits = tf.matmul(W, output) + b
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, target)
        losses.append(loss)

    total_loss = tf.add_n(losses)

Я получаю ошибку OOM при выделении градиентов для этого слоя, так как каждое умножение матрицы является другой операцией в графе, занимающем память. Есть ли способ предотвратить выделение TensorFlow всех этих операций одновременно?

1 ответов


это сложный график для оптимизации TensorFlow, так как активации из каждого слоя должны быть сохранены для агрегирования одного градиента для W. Одна из возможностей-пройти экспериментальные