Функция потери RMSE/ RMSLE в Keras
Я пытаюсь принять участие в моем первом конкурсе Kaggle, где RMSLE
дается как необходимая функция потери. Ибо я не нашел ничего, как реализовать это loss function
Я пытался согласиться на RMSE
. Я знаю, что это было частью Keras
в прошлом, есть ли способ использовать его в последней версии, возможно, с настраиваемой функцией через backend
?
это NN, который я разработал:
from keras.models import Sequential
from keras.layers.core import Dense , Dropout
from keras import regularizers
model = Sequential()
model.add(Dense(units = 128, kernel_initializer = "uniform", activation = "relu", input_dim = 28,activity_regularizer = regularizers.l2(0.01)))
model.add(Dropout(rate = 0.2))
model.add(Dense(units = 128, kernel_initializer = "uniform", activation = "relu"))
model.add(Dropout(rate = 0.2))
model.add(Dense(units = 1, kernel_initializer = "uniform", activation = "relu"))
model.compile(optimizer = "rmsprop", loss = "root_mean_squared_error")#, metrics =["accuracy"])
model.fit(train_set, label_log, batch_size = 32, epochs = 50, validation_split = 0.15)
я попробовал настроить root_mean_squared_error
функция, которую я нашел на GitHub, но для всех I знайте, что синтаксис-это не то, что требуется. Я думаю, что y_true
и y_pred
должен быть определен перед передачей на возврат, но я понятия не имею, как именно, я только начал с программирования на python, и я действительно не так хорош в математике...
from keras import backend as K
def root_mean_squared_error(y_true, y_pred):
return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
Я получаю следующую ошибку с этой функцией:
ValueError: ('Unknown loss function', ':root_mean_squared_error')
Спасибо за ваши идеи, я ценю любую помощь!
1 ответов
когда вы используете пользовательскую потерю, вам нужно поставить ее без кавычек, так как вы передаете объект функции, а не строку:
def root_mean_squared_error(y_true, y_pred):
return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
model.compile(optimizer = "rmsprop", loss = root_mean_squared_error,
metrics =["accuracy"])