Index ソフト・ハード PyTorch | MyDataset Sample |
MNIST FashionMNIST 機能・要件 構成・方式 タスク ライブラリ 導入 sample datasets mnistSample |
MyDataset
class MyDataset(torch.utils.data.Dataset): # 継承
def __init__(self, data, label, transform=None): # コンストラクタのオーバーライド
self.transform = transform # 引数の処理
self.data = data # 基本的にテンソルは Channel、Height、Width の三次元で考える。
self.data_num = len(data)
self.label = label
def __len__(self): # 新しいメソッドの定義、len で、引数の長さを返す。
return self.data_num
def __getitem__(self, idx): # インデックスを指定できるように __getitem__ を定義
if self.transform:
out_data = self.transform(self.data)[0][idx] # transform の出力からCH0を指定
out_label = self.label[idx]
else:
out_data = self.data[idx]
out_label = self.label[idx]
return out_data, out_label
MNIST
# myDataset.py
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision.io import read_image
class MyDataset(Dataset): # 継承
# FashionMNIST画像の場合はディレクトリ img_dir に保存
# それらのラベルは CSV ファイルに個別に保存(annotations_file)
# FashionMNISTの場合、いずれ引数の処理も引数の処理は無し(済)。
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_dir = img_dir # 基本 テンソルは Channel Height Width の三次元で考える。
self.img_labels = pd.read_csv(annotations_file)
self.transform = transform # 引数の処理
self.target_transform = target_transform # 引数の処理
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
FashionMNIST
# myDataset.py
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision.io import read_image
class MyDataset(Dataset): # 継承
# FashionMNIST画像の場合はディレクトリ img_dir に保存
# それらのラベルは CSV ファイルに個別に保存(annotations_file)
# FashionMNISTの場合、いずれ引数の処理も引数の処理は無し(済)。
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_dir = img_dir # 基本 テンソルは Channel Height Width の三次元で考える。
self.img_labels = pd.read_csv(annotations_file)
self.transform = transform # 引数の処理
self.target_transform = target_transform # 引数の処理
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
|
All Rights Reserved. Copyright (C) ITCL |