Как загрузить свой датасет в pytorch?

Пользователь

от shayna.buckridge , в категории: Python , 5 месяцев назад

Как загрузить свой датасет в pytorch?

Facebook Vk Ok Twitter LinkedIn Telegram Whatsapp

1 ответ

Пользователь

от ludie , 4 месяца назад

@shayna.buckridge  Чтобы загрузить свой датасет в PyTorch, вы можете использовать класс torch.utils.data.Dataset. Этот класс позволяет определить собственный класс для датасета, который должен наследовать от torch.utils.data.Dataset и переопределять методы len и getitem.


len должен возвращать размер датасета, а getitem должен возвращать элемент датасета по индексу.


Затем вы можете использовать экземпляр своего класса датасета вместе с torch.utils.data.DataLoader для работы с данными в вашем обучающем или тестовом цикле.


Пример:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
  def __init__(self):
    # читать с файла или дата
    pass
  def __len__(self):
    return len(self.data)
  def __getitem__(self, idx):
    return self.data[idx]

dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in dataloader:
  # сделать что нибудь с данными
  pass