Ранняя остановка с перекрестной проверкой Keras и sklearn GridSearchCV

Я хочу реализовать раннюю остановку с Keras и sklean's GridSearchCV.

пример рабочего кода ниже изменен с как сетчатый Поиск гиперпараметров для моделей глубокого обучения в Python с Keras. Набор данных может быть скачать здесь.

модификация добавляет Keras EarlyStopping класс обратного вызова для предотвращения чрезмерной подгонки. Для этого, чтобы быть эффективным, требуется

3 ответов


[ответ после того, как вопрос был отредактирован и уточнен:]

прежде чем бросаться в вопросы реализации, всегда полезно подумать о методологии и самой задаче; возможно, смешение ранней остановки с процедурой перекрестной проверки не хорошая идея.

давайте составим пример, чтобы выделить аргумент.

предположим, что вы действительно используете раннюю остановку со 100 эпохами и 5-кратным крестом проверка (CV) для выбора hyperparameter. Предположим также, что вы в конечном итоге с hyperparameter набор X дает лучшую производительность, скажем 89.3% бинарные точность классификации.

теперь предположим, что второй-лучший набор hyperparameter, Г, дает 89.2% точности. Внимательно изучая индивидуальные складки CV, вы видите, что в лучшем случае X, 3 из 5 складок CV исчерпали максимальные 100 эпох, в то время как в других 2 ранних остановках, скажем, в 89 и 93 эпохах соответственно.

каков был бы ваш вывод из такого эксперимента?

возможно, вы оказались бы в неубедительными ситуация; дальнейшие эксперименты могут показать, какой на самом деле лучший набор гиперпараметров, при условии, конечно, что вы подумали бы посмотрите на эти детали результатов в первую очередь. И излишне говорить, что если бы все это было автоматизировано через обратный вызов, вы могли бы пропустить свою лучшую модель, несмотря на то, что вы бы на самом деле пробовал.


вся идея CV неявно основана на аргументе" все остальные равны " (который, конечно, никогда не верен на практике, только аппроксимирован наилучшим образом). Если вы считаете, что количество эпох должно быть гиперпараметром, просто включите его явно в свое резюме как таковое, а не вставляя его через заднюю дверь ранней остановки, таким образом, возможно, поставив под угрозу весь процесс (не говоря уже о том, что ранняя остановка сама hyperparameter, patience).

не смешивание этих двух методов не означает, конечно, что вы не можете их использовать последовательно: как только вы получили свои лучшие гиперпараметры через CV, вы всегда можете использовать раннюю остановку при установке модель во всем наборе обучения (при условии, конечно, что у вас есть отдельный набор проверки).


поле глубоких нейронных сетей все еще (очень) молодо, и это правда, что ему еще предстоит установить свои рекомендации по "лучшей практике"; добавьте тот факт, что благодаря удивительному сообществу есть все виды инструментов, доступных в реализациях с открытым исходным кодом, и вы можете легко найти себя в (по общему признанию, заманчивое) положение смешивания вещей только потому, что они происходят иметься. Я вовсе не хочу сказать, что именно это вы пытаетесь сделать здесь - я просто призываю к большей осторожности при объединении идей, которые, возможно, не были разработаны для совместной работы...


вот как это сделать только один раскол.

fit_params['cl__validation_data'] = (X_val, y_val)
X_final = np.concatenate((X_train, X_val))
y_final = np.concatenate((y_train, y_val))
splits = [(range(len(X_train)), range(len(X_train), len(X_final)))]

GridSearchCV(estimator=model, param_grid=param_grid, cv=splits)I

Если вы хотите больше разбиений, вы можете использовать 'cl__validation_split' с фиксированным соотношением и построить разбиения, которые соответствуют этим критериям.

это может быть слишком параноидально, но я не использую набор данных ранней остановки в качестве набора данных проверки, поскольку он косвенно использовался для создания модели.

Я также думаю, что если вы используете раннюю остановку с вашей окончательной моделью, то это также должно быть сделано, когда вы выполнение поиска гиперпараметров.


[старый ответ, прежде чем вопрос был отредактирован и уточнен-см. обновленный и принятый ответ выше]

Я не уверен, что понял вашу точную проблему (ваш вопрос совершенно неясен, и вы включаете много несвязанных деталей, что никогда не хорошо, когда задаете такой вопрос - см. здесь).

не обязательно (и на самом деле не должен) включать любые аргументы о данных проверки в ваш model = KerasClassifier() вызов функции (интересно, почему ты не чувствуешь того же обучение сведения здесь, тоже). Ваш grid.fit() позаботится как о тренировке и складки проверки. Таким образом, если вы хотите сохранить значения гиперпараметра, включенные в ваш пример, этот вызов функции должен быть просто

model = KerasClassifier(build_fn=create_model, 
                        epochs=100, batch_size=32,
                        shuffle=True,
                        verbose=1)

вы можете увидеть некоторые четкие и хорошо объясненные примеры использования GridSearchCV С Керрас здесь.