Пользовательская функция потери в Keras
Я работаю над подходом к инкрементальному классификатору классов изображений, используя CNN в качестве экстрактора функций и полностью подключенного блока для классификации.
во-первых, я сделал тонкую настройку VGG для каждой обученной сети для выполнения новой задачи. Как только сеть подготовлена для новой задачи, я храню несколько примеров для каждого класса, чтобы не забыть, когда новые классы доступны.
когда некоторые классы доступны, я должен вычислить каждый вывод примеров, включенных в примеры для новых классов. Теперь, добавляя нули к выходам для старых классов и добавляя метку, соответствующую каждому новому классу на выходах новых классов, у меня есть мои новые метки, i.e: если ввести 3 новых класса....
вывод старого типа класса:[0.1, 0.05, 0.79, ..., 0 0 0]
вывод нового типа класса:[0.1, 0.09, 0.3, 0.4, ..., 1 0 0]
* * последние выходные данные соответствуют классу.
мой вопрос в том, как я могу изменить функцию потери для пользовательского, чтобы тренироваться для новых классов? Функция потерь, которую я хотите реализовать определяется как:
где потеря дистилляции соответствует выходам для старых классов, чтобы избежать забывания, а потеря классификации соответствует новым классам.
Если вы можете предоставить мне образец кода для изменения функции потерь в keras, было бы неплохо.
спасибо!!!!!
1 ответов
все, что вам нужно сделать, это определить функцию для этого, используя функции бэкэнда keras для вычислений. Функция должна принимать истинные значения и предсказанные значения модели.
теперь, поскольку я не уверен в том, что такое g, q, x an y в вашей функции, я просто создам базовый пример здесь, не заботясь о том, что это значит или является ли это фактической полезной функцией:
import keras.backend as K
def customLoss(yTrue,yPred):
return K.sum(K.log(yTrue) - K.log(yPred))
все функции бэкэнда можно увидеть здесь: https://keras.io/backend/
после этого скомпилируйте свою модель, используя эту функцию вместо обычной:
model.compile(loss=customLoss, optimizer = .....)