PyTorch: как использовать DataLoaders для пользовательских наборов данных
как использовать torch.utils.data.Dataset и torch.utils.data.DataLoader по вашим собственным данным (не только torchvision.datasets)?
есть ли способ использовать встроенный DataLoaders, который они используют на TorchVisionDatasets для использования в любом наборе данных?
2 ответов
Да, это возможно. Просто создайте объекты самостоятельно, например
import torch.utils.data as data_utils
train = data_utils.TensorDataset(features, targets)
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)
здесь features и targets являются тензоры. features должно быть 2-D, т. е. Матрица, где каждая строка представляет собой один обучающий образец, и targets может быть 1-D или 2-D, в зависимости от того, пытаетесь ли вы предсказать скаляр или вектор.
надеюсь, что это поможет!
редактировать: ответ на вопрос @sarthak
в основном да. Если вы создайте объект типа TensorData, затем конструктор исследует, являются ли первые измерения тензора объектов (который фактически называется data_tensor) и целевой тензор (называемый target_tensor) имеют одинаковую длину:
assert data_tensor.size(0) == target_tensor.size(0)
однако, если вы хотите передать эти данные в нейронную сеть впоследствии, вам нужно быть осторожным. В то время как слои свертки работают с данными, подобными вашим, (Я думаю), все другие типы слоев ожидают, что данные будут даны в матричной форме. Итак, если вы сталкиваетесь с такой проблемой, тогда простым решением было бы преобразовать ваш 4D-dataset (заданный как какой-то тензор, например FloatTensor) в матрицу с помощью метода view. Для вашего набора данных 5000xnxnx3, это будет выглядеть так:
2d_dataset = 4d_dataset.view(5000, -1)
(стоимость -1 говорит PyTorch, чтобы выяснить длину второго измерения автоматически.)
вы можете легко сделать это, расширяя data.Dataset класса.
Согласно API, все, что вам нужно сделать, это реализовать две функции: __getitem__ и __len__.
затем вы можете обернуть набор данных с помощью DataLoader, как показано в API и в ответе @pho7.
Я думаю ImageFolder класс-это ссылка. См. код здесь.