PyTorch: как использовать DataLoaders для пользовательских наборов данных
как использовать torch.utils.data.Dataset
и torch.utils.data.DataLoader
по вашим собственным данным (не только torchvision.datasets
)?
есть ли способ использовать встроенный DataLoaders
, который они используют на TorchVisionDatasets
для использования в любом наборе данных?
2 ответов
Да, это возможно. Просто создайте объекты самостоятельно, например
import torch.utils.data as data_utils
train = data_utils.TensorDataset(features, targets)
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)
здесь features
и targets
являются тензоры. features
должно быть 2-D, т. е. Матрица, где каждая строка представляет собой один обучающий образец, и targets
может быть 1-D или 2-D, в зависимости от того, пытаетесь ли вы предсказать скаляр или вектор.
надеюсь, что это поможет!
редактировать: ответ на вопрос @sarthak
в основном да. Если вы создайте объект типа TensorData
, затем конструктор исследует, являются ли первые измерения тензора объектов (который фактически называется data_tensor
) и целевой тензор (называемый target_tensor
) имеют одинаковую длину:
assert data_tensor.size(0) == target_tensor.size(0)
однако, если вы хотите передать эти данные в нейронную сеть впоследствии, вам нужно быть осторожным. В то время как слои свертки работают с данными, подобными вашим, (Я думаю), все другие типы слоев ожидают, что данные будут даны в матричной форме. Итак, если вы сталкиваетесь с такой проблемой, тогда простым решением было бы преобразовать ваш 4D-dataset (заданный как какой-то тензор, например FloatTensor
) в матрицу с помощью метода view
. Для вашего набора данных 5000xnxnx3, это будет выглядеть так:
2d_dataset = 4d_dataset.view(5000, -1)
(стоимость -1
говорит PyTorch, чтобы выяснить длину второго измерения автоматически.)
вы можете легко сделать это, расширяя data.Dataset
класса.
Согласно API, все, что вам нужно сделать, это реализовать две функции: __getitem__
и __len__
.
затем вы можете обернуть набор данных с помощью DataLoader, как показано в API и в ответе @pho7.
Я думаю ImageFolder
класс-это ссылка. См. код здесь.