DeSSL.model package
Submodules
DeSSL.model.ALImodel module
- class Discriminator_x
Bases: torch.nn.modules.module.Module
The discriminator of x for CIFAR10, as reported in Adversarially Learned Inference.
- forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
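The difference between calling the module instance and calling forward() directly can be checked with a forward hook. The following minimal sketch uses a hypothetical Double module rather than a DeSSL class; the same behaviour applies to every forward() listed on this page.

```python
import torch
from torch import nn


class Double(nn.Module):
    # Hypothetical stand-in for any of the modules documented here.
    def forward(self, x):
        return 2 * x


calls = []
model = Double()
# Record every time the registered forward hook actually runs.
model.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.ones(3)
y1 = model(x)          # goes through nn.Module.__call__, so the hook fires
y2 = model.forward(x)  # bypasses __call__, so the hook is silently skipped

print(calls)  # ['hook']
```

Both calls return the same tensor; only the first one runs the hook, which is why the instance call is the recommended form.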
- class Discriminator_x_z(num_classes: int = 10)
Bases: torch.nn.modules.module.Module
The discriminator of x and z for CIFAR10, as reported in Adversarially Learned Inference.
- forward(x, z)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class Discriminator_z
Bases: torch.nn.modules.module.Module
The discriminator of z for CIFAR10, as reported in Adversarially Learned Inference.
- forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class Generator_x
Bases: torch.nn.modules.module.Module
The generator of x for CIFAR10, as reported in Adversarially Learned Inference.
- forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class Generator_z
Bases: torch.nn.modules.module.Module
The generator of z for CIFAR10, as reported in Adversarially Learned Inference.
- forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
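How the ALI pieces fit together can be sketched as follows, assuming Generator_z plays the role of ALI's encoder (x to z), Generator_x the decoder (z to x), and the discriminator scores joint (x, z) pairs. This is a minimal sketch, not the DeSSL implementation: the tiny MLPs and the dimensions X_DIM, Z_DIM and BATCH are hypothetical (the DeSSL classes are convolutional networks for CIFAR10), and the swapped-label generator loss is the common non-saturating choice.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes chosen only for illustration.
X_DIM, Z_DIM, BATCH = 8, 4, 16

gen_x = nn.Sequential(nn.Linear(Z_DIM, 32), nn.ReLU(), nn.Linear(32, X_DIM))  # decoder: z -> x
gen_z = nn.Sequential(nn.Linear(X_DIM, 32), nn.ReLU(), nn.Linear(32, Z_DIM))  # encoder: x -> z
disc = nn.Sequential(nn.Linear(X_DIM + Z_DIM, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))


def disc_logit(x, z):
    # The discriminator scores joint (x, z) pairs, not x alone.
    return disc(torch.cat([x, z], dim=1))


x_real = torch.randn(BATCH, X_DIM)   # stand-in for a batch of data
z_prior = torch.randn(BATCH, Z_DIM)  # samples from the prior p(z)

logit_enc = disc_logit(x_real, gen_z(x_real))    # pairs (x, G_z(x)), labelled 1
logit_dec = disc_logit(gen_x(z_prior), z_prior)  # pairs (G_x(z), z), labelled 0

ones = torch.ones_like(logit_enc)
zeros = torch.zeros_like(logit_dec)

# Discriminator loss: separate encoder pairs from decoder pairs.
loss_d = (F.binary_cross_entropy_with_logits(logit_enc, ones)
          + F.binary_cross_entropy_with_logits(logit_dec, zeros))

# Encoder/decoder loss: fool the discriminator by swapping the labels.
loss_g = (F.binary_cross_entropy_with_logits(logit_enc, zeros)
          + F.binary_cross_entropy_with_logits(logit_dec, ones))
```

In a training loop the two losses would be minimized in alternation by separate optimizers; detaching the generator outputs for the discriminator step is omitted here for brevity.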
DeSSL.model.VAE module
- class VAE
Bases: torch.nn.modules.module.Module
Based on the variational autoencoder (VAE).
- decode(z)
- encode(x)
- forward(x, y)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- reparameterize(mu, logvar)
- training: bool
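reparameterize(mu, logvar) is listed without a description. Assuming it implements the standard VAE reparameterization trick, z = μ + σ ⊙ ε with σ = exp(logvar / 2) and ε ~ N(0, I), a minimal sketch looks like this; the function names here are illustrative, and kl_to_standard_normal is a hypothetical helper showing the matching KL term of the VAE objective.

```python
import torch


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # sigma = exp(logvar / 2); drawing eps separately keeps z differentiable
    # with respect to mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


def kl_to_standard_normal(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch.
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)).mean()


mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)
z = reparameterize(mu, logvar)
z.sum().backward()  # gradients reach mu and logvar through the sampled z
```

Since dz/dμ = 1, mu.grad ends up as a tensor of ones; sampling z directly from N(μ, σ²) would give no gradient path at all.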
DeSSL.model.lenet module
- class LeNet5(num_classes: int = 10)
Bases: torch.nn.modules.module.Module
LeNet-5, used to measure the performance of the algorithms on MNIST.
- Parameters
num_classes – The number of categories.
- forward(img)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
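The exact layers of LeNet5 are not documented here. As a rough picture, a minimal sketch of the classic LeNet-5 layout for 1×28×28 MNIST input follows; LeNet5Sketch is a hypothetical name, and the ReLU activations and max pooling are a common modern substitution for the original tanh and average pooling, not necessarily what DeSSL uses.

```python
import torch
from torch import nn


class LeNet5Sketch(nn.Module):
    """Classic LeNet-5 layout; an illustration, not DeSSL's exact layers."""

    def __init__(self, num_classes: int = 10, in_channels: int = 1):
        super().__init__()
        self.features = nn.Sequential(
            # padding=2 lets a 28x28 MNIST digit stand in for LeNet-5's 32x32 input
            nn.Conv2d(in_channels, 6, kernel_size=5, padding=2), nn.ReLU(),
            nn.MaxPool2d(2),                              # 28x28 -> 14x14
            nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),   # 14x14 -> 10x10
            nn.MaxPool2d(2),                              # 10x10 -> 5x5
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
            nn.Linear(120, 84), nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, img):
        return self.classifier(self.features(img))


model = LeNet5Sketch(num_classes=10)
logits = model(torch.randn(8, 1, 28, 28))  # a batch of 8 fake MNIST images
```

The output has one logit per category, so logits has shape (8, num_classes).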
- class LeNet5_SVHN(num_classes: int = 10)
Bases: torch.nn.modules.module.Module
LeNet-5, used to measure the performance of the algorithms on SVHN.
- Parameters
num_classes – The number of categories.
- forward(img)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
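SVHN digits are 3-channel 32×32 images while MNIST digits are 1-channel 28×28, so LeNet5_SVHN presumably differs from LeNet5 at least in its input channels. Assuming both keep the classic LeNet-5 feature extractor (5×5 convolutions with 6 and 16 output channels, each followed by 2×2 pooling), the arithmetic below computes the flattened size that feeds the first fully connected layer; conv_out and lenet_flat_features are hypothetical helpers, and the first_padding values are illustrative assumptions.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output-size formula shared by convolution and pooling layers.
    return (size + 2 * padding - kernel) // stride + 1


def lenet_flat_features(size: int, first_padding: int) -> int:
    # conv5 -> pool2 -> conv5 -> pool2, ending with 16 channels.
    size = conv_out(size, 5, padding=first_padding)
    size = conv_out(size, 2, stride=2)
    size = conv_out(size, 5)
    size = conv_out(size, 2, stride=2)
    return 16 * size * size


print(lenet_flat_features(32, first_padding=0))  # SVHN 32x32, no padding: 400
print(lenet_flat_features(28, first_padding=2))  # MNIST 28x28, padding 2: 400
print(lenet_flat_features(32, first_padding=2))  # 32x32 with padding 2: 576
```

The first fully connected layer must match this number, which is why a network built for one input size cannot be applied unchanged to the other.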
DeSSL.model.resnet module
- resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet
ResNet-101 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).
- Parameters
pretrained – If True, returns a model pre-trained on ImageNet.
progress – If True, displays a progress bar of the download to stderr.
- resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet
ResNet-152 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).
- Parameters
pretrained – If True, returns a model pre-trained on ImageNet.
progress – If True, displays a progress bar of the download to stderr.
- resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet
ResNet-18 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).
- Parameters
pretrained – If True, returns a model pre-trained on ImageNet.
progress – If True, displays a progress bar of the download to stderr.
- resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet
ResNet-34 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).
- Parameters
pretrained – If True, returns a model pre-trained on ImageNet.
progress – If True, displays a progress bar of the download to stderr.
- resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet
ResNet-50 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).
- Parameters
pretrained – If True, returns a model pre-trained on ImageNet.
progress – If True, displays a progress bar of the download to stderr.
- resnext101_32x8d(*args, **kwargs)
- resnext50_32x4d(*args, **kwargs)
- wide_resnet101_2(*args, **kwargs)
- wide_resnet50_2(*args, **kwargs)
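All of the ResNet variants above are built from residual blocks that add a learned residual F(x) to an identity shortcut, y = relu(F(x) + x). The sketch below shows a simplified two-convolution block of the kind used by ResNet-18 and ResNet-34 (deeper variants use bottleneck blocks); ResidualBlockSketch is a hypothetical name, and unlike torchvision's blocks it supports no stride or downsampling shortcut.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ResidualBlockSketch(nn.Module):
    """Two 3x3 convolutions plus an identity shortcut: y = relu(F(x) + x)."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)  # the residual (shortcut) connection


block = ResidualBlockSketch(16)
x = torch.randn(2, 16, 8, 8)
y = block(x)
```

Because padding=1 preserves the spatial size and the channel count is unchanged, the shortcut can be a plain addition and y has the same shape as x.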
DeSSL.model.toy module
- class Ladder_MLP(input_shape: tuple, num_neurons: List[int], sigma_noise: List[float], input_sigma_noise: float, **kwargs)
Bases: torch.nn.modules.module.Module
- Parameters
input_shape – The shape of the inputs.
num_neurons – The number of neurons in each layer.
sigma_noise – The list of \(\sigma\) values of the Gaussian noise.
input_sigma_noise – The \(\sigma\) of the noise added to the supervised learning path.
- clear_path(*input)
- decoder_path()
- forward(path_name, *input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- get_loss_d(lam_list)
- noise_path(*input)
- training: bool
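get_loss_d(lam_list) is listed without a description. In the Ladder Network the unsupervised denoising cost is a weighted sum over layers of the squared error between each clean-path activation and its (batch-normalized) decoder reconstruction, with one weight λ per layer. Assuming lam_list holds those per-layer weights, a minimal sketch follows; denoising_cost and its argument layout are hypothetical and not the DeSSL signature.

```python
import torch


def denoising_cost(clean_z, recon_z, lam_list):
    # clean_z[l], recon_z[l]: (batch, width_l) tensors for layer l (l = 0 is the input).
    # Each layer's squared error is averaged over batch and width, then weighted by lambda_l.
    assert len(clean_z) == len(recon_z) == len(lam_list)
    cost = torch.zeros(())
    for z, z_hat, lam in zip(clean_z, recon_z, lam_list):
        cost = cost + lam * ((z - z_hat) ** 2).mean()
    return cost


clean = [torch.ones(4, 3), torch.ones(4, 2)]
recon = [torch.zeros(4, 3), torch.ones(4, 2)]
cost = denoising_cost(clean, recon, [1000.0, 10.0])  # 1000 * 1.0 + 10 * 0.0 = 1000
```

Large weights on the lower layers (here the input layer) are a common choice, because reconstructing the input carries most of the unsupervised signal.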
- class ToyNet(num_classes: int)
Bases: torch.nn.modules.module.Module
A toy neural network used to measure the performance of the algorithms on MNIST.
- Parameters
num_classes – The number of categories.
- forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- batch_normalization(batch, return_mean_and_std=False)
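batch_normalization is also listed without a description. Assuming it standardizes each feature over the batch dimension and, when return_mean_and_std=True, also returns the statistics it used (the Ladder Network reuses clean-path statistics to normalize decoder reconstructions), a minimal sketch is below. The name batch_normalization_sketch and the eps parameter are hypothetical additions for numerical safety.

```python
import torch


def batch_normalization_sketch(batch, return_mean_and_std=False, eps=1e-10):
    # Per-feature statistics over the batch dimension (dim 0).
    mean = batch.mean(dim=0, keepdim=True)
    var = ((batch - mean) ** 2).mean(dim=0, keepdim=True)  # biased variance
    std = torch.sqrt(var + eps)
    normalized = (batch - mean) / std
    if return_mean_and_std:
        return normalized, mean, std
    return normalized


x = torch.tensor([[1.0, 10.0], [3.0, 30.0]])
out, mean, std = batch_normalization_sketch(x, return_mean_and_std=True)
# mean ≈ [[2, 20]], std ≈ [[1, 10]], out ≈ [[-1, -1], [1, 1]]
```

Returning the statistics lets a caller normalize a second tensor (for example a reconstruction) with exactly the same mean and std as the first.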