Пользовательская функция потери в Keras

Я работаю над подходом к инкрементальному классификатору классов изображений, используя CNN в качестве экстрактора функций и полностью подключенного блока для классификации.

во-первых, я сделал тонкую настройку VGG для каждой обученной сети для выполнения новой задачи. Как только сеть подготовлена для новой задачи, я храню несколько примеров для каждого класса, чтобы не забыть, когда новые классы доступны.

когда некоторые классы доступны, я должен вычислить каждый вывод примеров, включенных в примеры для новых классов. Теперь, добавляя нули к выходам для старых классов и добавляя метку, соответствующую каждому новому классу на выходах новых классов, у меня есть мои новые метки, i.e: если ввести 3 новых класса....

вывод старого типа класса:[0.1, 0.05, 0.79, ..., 0 0 0]

вывод нового типа класса:[0.1, 0.09, 0.3, 0.4, ..., 1 0 0] * * последние выходные данные соответствуют классу.

мой вопрос в том, как я могу изменить функцию потери для пользовательского, чтобы тренироваться для новых классов? Функция потерь, которую я хотите реализовать определяется как:

loss function

где потеря дистилляции соответствует выходам для старых классов, чтобы избежать забывания, а потеря классификации соответствует новым классам.

Если вы можете предоставить мне образец кода для изменения функции потерь в keras, было бы неплохо.

спасибо!!!!!

1 ответов


все, что вам нужно сделать, это определить функцию для этого, используя функции бэкэнда keras для вычислений. Функция должна принимать истинные значения и предсказанные значения модели.

теперь, поскольку я не уверен в том, что такое g, q, x an y в вашей функции, я просто создам базовый пример здесь, не заботясь о том, что это значит или является ли это фактической полезной функцией:

import keras.backend as K

def customLoss(yTrue,yPred):
    return K.sum(K.log(yTrue) - K.log(yPred))

все функции бэкэнда можно увидеть здесь: https://keras.io/backend/

после этого скомпилируйте свою модель, используя эту функцию вместо обычной:

model.compile(loss=customLoss, optimizer = .....)