DeSSL.data package

Submodules

DeSSL.data.dataset module

class SemiDataLoader(label_loader: torch.utils.data.dataloader.DataLoader, unlabel_loader: torch.utils.data.dataloader.DataLoader, num_iteration: int)

Bases: object
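
The class is listed here without a description. As a rough illustration only, a loader with this constructor could pair labeled and unlabeled batches for a fixed number of iterations per epoch, restarting whichever loader runs out first. The name `PairedLoaderSketch` and the restart behavior are assumptions made for this sketch, not the DeSSL implementation, and plain lists stand in for `DataLoader` objects.

```python
from typing import Any, Iterable, Iterator, Tuple


class PairedLoaderSketch:
    """Hypothetical stand-in: yields (labeled_batch, unlabeled_batch) pairs."""

    def __init__(self, label_loader: Iterable, unlabel_loader: Iterable,
                 num_iteration: int) -> None:
        self.label_loader = label_loader
        self.unlabel_loader = unlabel_loader
        self.num_iteration = num_iteration

    def __len__(self) -> int:
        # One epoch is defined by num_iteration, not by either loader's length.
        return self.num_iteration

    def __iter__(self) -> Iterator[Tuple[Any, Any]]:
        label_iter = iter(self.label_loader)
        unlabel_iter = iter(self.unlabel_loader)
        for _ in range(self.num_iteration):
            try:
                labeled = next(label_iter)
            except StopIteration:
                # Assumed behavior: restart the labeled loader when exhausted.
                label_iter = iter(self.label_loader)
                labeled = next(label_iter)
            try:
                unlabeled = next(unlabel_iter)
            except StopIteration:
                # Assumed behavior: restart the unlabeled loader when exhausted.
                unlabel_iter = iter(self.unlabel_loader)
                unlabeled = next(unlabel_iter)
            yield labeled, unlabeled


pairs = list(PairedLoaderSketch(["l0", "l1"], ["u0", "u1", "u2"], num_iteration=4))
```

In this sketch the epoch length is decoupled from the sizes of the two underlying loaders, which is useful when the labeled set is much smaller than the unlabeled one.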

class SemiDataset(root: str, num_labels_per_class: int, dataset: Type[torch.utils.data.dataset.Dataset], num_classes: int, label_transform: Optional[Callable] = None, unlabel_transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, norm: Optional[Callable] = None, download: bool = False, include_labeled_data: bool = True)

Bases: object

A class representing a semi-supervised dataset.

Example

>>> from torchvision.datasets import CIFAR10
>>> from torchvision import transforms as tf
>>> from DeSSL.data.dataset import SemiDataset
>>> root = '...'
>>> # initialize a semi CIFAR10 with 1000 labeled images for each class.
>>> norm = tf.Compose([tf.ToTensor(),
...                    tf.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
>>> semi_dataset = SemiDataset(root, 1000, CIFAR10, 10, norm=norm)
Parameters
  • root – The root directory where the dataset exists or will be saved.

  • num_labels_per_class – The number of labeled samples for each class.

  • dataset – An instantiable class representing a Dataset.

  • num_classes – The number of classes.

  • label_transform – A function/transform that takes in a labeled image and returns a transformed version, e.g., transforms.RandomCrop.

  • unlabel_transform – A function/transform that takes in an unlabeled image and returns a transformed version, e.g., transforms.RandomCrop.

  • test_transform – A function/transform that takes in a test image and returns a transformed version, e.g., transforms.RandomCrop.

  • norm – Normalization applied after all other transforms.

  • download – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

  • include_labeled_data – If True, the unlabeled data will also include the labeled data.

Returns

A semi-supervised dataset.
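
To make the roles of `num_labels_per_class` and `include_labeled_data` concrete, the following is a minimal sketch of how sample indices could be divided into a labeled and an unlabeled pool. The helper `split_indices`, its `seed` argument, and the random per-class selection are assumptions for illustration; the actual selection strategy used by SemiDataset is not documented here.

```python
import random
from collections import defaultdict
from typing import List, Sequence, Tuple


def split_indices(targets: Sequence[int], num_labels_per_class: int,
                  num_classes: int, include_labeled_data: bool = True,
                  seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: choose labeled and unlabeled sample indices."""
    rng = random.Random(seed)

    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for index, target in enumerate(targets):
        by_class[target].append(index)

    # Pick num_labels_per_class samples from every class (random choice is assumed).
    labeled: List[int] = []
    for cls in range(num_classes):
        candidates = by_class[cls]
        rng.shuffle(candidates)
        labeled.extend(candidates[:num_labels_per_class])

    if include_labeled_data:
        # The unlabeled pool covers every sample, labeled ones included.
        unlabeled = list(range(len(targets)))
    else:
        # The unlabeled pool excludes the samples chosen as labeled.
        chosen = set(labeled)
        unlabeled = [i for i in range(len(targets)) if i not in chosen]
    return sorted(labeled), unlabeled


targets = [0, 1, 0, 1, 0, 1, 0, 1]
labeled, unlabeled = split_indices(targets, num_labels_per_class=2, num_classes=2)
labeled_only, rest = split_indices(targets, 2, 2, include_labeled_data=False)
```

With `include_labeled_data=True` the two pools overlap; with `False` they are disjoint and together cover the whole training set.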

get_dataloader(label_batch_size: int, unlabel_batch_size: Optional[int] = None, test_batch_size: Optional[int] = None, num_iteration: Optional[int] = None, shuffle: bool = True, num_workers: int = 0, pin_memory: bool = True, drop_last: bool = False, return_num_classes: bool = True) Union[Tuple[DeSSL.data.dataset.SemiDataLoader, torch.utils.data.dataloader.DataLoader], Tuple[DeSSL.data.dataset.SemiDataLoader, torch.utils.data.dataloader.DataLoader, int]]

Build the semi-supervised training dataloader and the testing dataloader.

Parameters
  • label_batch_size – The batch size of labeled data.

  • unlabel_batch_size – The batch size of unlabeled data. If None, use label_batch_size instead.

  • test_batch_size – The batch size of testing data. If None, use label_batch_size + unlabel_batch_size instead.

  • num_iteration – The number of iterations in each epoch. If None, use the number of iterations of the labeled (supervised) dataloader instead.

  • shuffle – Set to True to have the training data reshuffled at every epoch.

  • num_workers – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.

  • pin_memory – If True, the data loader will copy Tensors into CUDA pinned memory before returning them.

  • drop_last – Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.

  • return_num_classes – If True, also return the number of classes as the last return value.

Returns

A semi-supervised training dataloader and a testing dataloader, followed by the number of classes if return_num_classes is True.
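
The `None` defaults described in the parameter list can be traced with a small sketch. The helpers `resolve_batch_sizes` and `default_num_iteration` are hypothetical names introduced here; they only restate the documented fallbacks, together with the usual PyTorch convention that a dataloader's length is rounded down when `drop_last` is True and rounded up otherwise.

```python
import math
from typing import Optional, Tuple


def resolve_batch_sizes(label_batch_size: int,
                        unlabel_batch_size: Optional[int] = None,
                        test_batch_size: Optional[int] = None) -> Tuple[int, int, int]:
    """Hypothetical helper mirroring the documented None fallbacks."""
    if unlabel_batch_size is None:
        unlabel_batch_size = label_batch_size
    if test_batch_size is None:
        test_batch_size = label_batch_size + unlabel_batch_size
    return label_batch_size, unlabel_batch_size, test_batch_size


def default_num_iteration(num_labeled: int, label_batch_size: int,
                          drop_last: bool = False) -> int:
    """Hypothetical helper: number of batches a labeled dataloader would yield."""
    if drop_last:
        return num_labeled // label_batch_size
    return math.ceil(num_labeled / label_batch_size)


sizes_default = resolve_batch_sizes(64)
sizes_custom = resolve_batch_sizes(64, unlabel_batch_size=448)
iterations = default_num_iteration(num_labeled=10000, label_batch_size=64)
iterations_dropped = default_num_iteration(10000, 64, drop_last=True)
```

Note that, per the signature above, leaving `return_num_classes` at its default of True yields three return values rather than two.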

semi_cifar10(*args, **kwargs) DeSSL.data.dataset.SemiDataset

A partial application of SemiDataset with dataset=CIFAR10, num_classes=10, norm=tf.Compose([tf.ToTensor(), tf.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]) already supplied.

semi_imagenet(*args, **kwargs) DeSSL.data.dataset.SemiDataset

A partial application of SemiDataset with dataset=ImageFolder, num_classes=1000, norm=tf.Compose([tf.ToTensor(), tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) already supplied.

semi_mnist(*args, **kwargs) DeSSL.data.dataset.SemiDataset

A partial application of SemiDataset with dataset=MNIST, num_classes=10, norm=tf.Compose([tf.ToTensor(), tf.Normalize((0.1307), (0.3081))]) already supplied.

semi_svhn(*args, **kwargs) DeSSL.data.dataset.SemiDataset

A partial application of SemiDataset with dataset=SVHN, num_classes=10, norm=tf.Compose([tf.ToTensor(), tf.Normalize((0.4390, 0.4443, 0.4692), (0.1189, 0.1222, 0.1049))]) already supplied.

Example

>>> from DeSSL.data import semi_cifar10
>>> root = '...'
>>> # initialize a semi CIFAR10 with 1000 labeled images for each class.
>>> semi_cifar = semi_cifar10(root, 1000)

or:

>>> from DeSSL.data import SEMI_DATASET_REGISTRY
>>> root = '...'
>>> # initialize a semi CIFAR10 with 1000 labeled images for each class.
>>> semi_cifar = SEMI_DATASET_REGISTRY('semi_cifar10')(root, 1000)
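
The two call styles above follow a common pattern: a `functools.partial` pre-fills some constructor arguments, and a registry maps a string name to that partial. The sketch below shows the pattern with stand-ins; `DatasetStub`, `Registry`, and the `/tmp/data` path are hypothetical and do not reproduce how DeSSL defines SemiDataset or SEMI_DATASET_REGISTRY.

```python
from functools import partial
from typing import Any, Callable, Dict, Optional


class DatasetStub:
    """Hypothetical stand-in for SemiDataset; it only stores its arguments."""

    def __init__(self, root: str, num_labels_per_class: int, dataset: Any,
                 num_classes: int, norm: Optional[Callable] = None) -> None:
        self.root = root
        self.num_labels_per_class = num_labels_per_class
        self.dataset = dataset
        self.num_classes = num_classes
        self.norm = norm


class Registry:
    """Hypothetical name-to-factory registry, looked up by calling it."""

    def __init__(self) -> None:
        self._entries: Dict[str, Callable] = {}

    def register(self, name: str, factory: Callable) -> Callable:
        self._entries[name] = factory
        return factory

    def __call__(self, name: str) -> Callable:
        return self._entries[name]


REGISTRY = Registry()

# Pre-fill the dataset-specific arguments; callers supply only the rest.
semi_cifar10_stub = REGISTRY.register(
    "semi_cifar10", partial(DatasetStub, dataset="CIFAR10", num_classes=10))

direct = semi_cifar10_stub("/tmp/data", 1000)            # placeholder path
looked_up = REGISTRY("semi_cifar10")("/tmp/data", 1000)  # same factory by name
```

Both calls go through the same partial, so they build equivalent objects; the registry form is convenient when the dataset name comes from a configuration file or command-line argument.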

DeSSL.data.accelerator module

accelerated_fashionmnist(*args, **kwargs) Type[torchvision.datasets.mnist.MNIST]

A partial application of accelerator with mnist=FashionMNIST, mean=0.286, std=0.352 already supplied; calling it with a device returns an AcceleratedMNIST class.

accelerated_mnist(*args, **kwargs) Type[torchvision.datasets.mnist.MNIST]

A partial application of accelerator with mnist=MNIST, mean=0.1307, std=0.3081 already supplied; calling it with a device returns an AcceleratedMNIST class.

Example

>>> import torch
>>> from DeSSL.data import accelerated_mnist
>>> FastMNIST = accelerated_mnist(torch.device('cuda:0'))

or:

>>> import torch
>>> from DeSSL.data import ACCELERATOR_REGISTRY
>>> FastMNIST = ACCELERATOR_REGISTRY('mnist')(torch.device('cuda:0'))

accelerator(device: torch.device, mnist: Type[torchvision.datasets.mnist.MNIST], mean: float = 0.5, std: float = 1.0) Type[torchvision.datasets.mnist.MNIST]

An accelerator for MNIST-like dataset.

Note

The accelerator transfers all data to a device (e.g., a GPU) during initialization. As a result, no subprocesses are used for data loading, and the dataloader must be created with num_workers=0.

Example

>>> import torch
>>> from DeSSL.data import accelerator
>>> from torchvision.datasets import MNIST
>>> FastMNIST = accelerator(torch.device('cuda:0'), MNIST, 0.1307, 0.3081)
Parameters
  • device – The device on which the dataset will be allocated.

  • mnist – An MNIST-like dataset class.

  • mean – The mean value of the dataset.

  • std – The standard deviation of the dataset.

Returns

An AcceleratedMNIST class.
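
Because accelerator takes a dataset class and returns a class, it behaves like a class factory. The sketch below illustrates that shape in plain Python: the returned subclass normalizes all samples once at construction, so per-item access does no further work. `TinyDataset`, `accelerator_sketch`, and the 0-255 pixel scaling are assumptions for illustration; the device transfer described in the Note above is only marked with a comment, and the real AcceleratedMNIST may differ.

```python
from typing import List, Tuple, Type


class TinyDataset:
    """Hypothetical stand-in for an MNIST-like dataset with raw 0-255 pixels."""

    def __init__(self) -> None:
        self.data: List[List[float]] = [[0, 255], [128, 64]]
        self.targets: List[int] = [3, 7]


def accelerator_sketch(base: Type[TinyDataset], mean: float = 0.5,
                       std: float = 1.0) -> Type[TinyDataset]:
    """Hypothetical class factory: returns a subclass with preprocessed data."""

    class Accelerated(base):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # Scale to [0, 1] and normalize every sample once, up front.
            # A tensor-based version would also move the data to the device here.
            self.data = [[(pixel / 255.0 - mean) / std for pixel in image]
                         for image in self.data]

        def __getitem__(self, index: int) -> Tuple[List[float], int]:
            # No per-item transform: the data is already normalized.
            return self.data[index], self.targets[index]

        def __len__(self) -> int:
            return len(self.data)

    return Accelerated


FastTiny = accelerator_sketch(TinyDataset, mean=0.5, std=1.0)
fast_dataset = FastTiny()
first_image, first_label = fast_dataset[0]
```

Doing the preprocessing once trades memory for speed, which suits small datasets such as MNIST that fit entirely on the target device.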

Module contents