DeSSL.model package

Submodules

DeSSL.model.ALImodel module

class Discriminator_x

Bases: torch.nn.modules.module.Module

The discriminator for x on CIFAR-10, as described in Adversarially Learned Inference (ALI).

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance rather than this method directly: calling the instance runs the registered hooks, whereas calling forward() directly silently skips them.
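The note can be checked with any standard PyTorch module. In the sketch below, nn.Linear is only a stand-in for the classes of this package: a forward hook fires when the instance is called, but not when forward() is called directly.

```python
import torch
from torch import nn

hook_calls = []

# Any nn.Module behaves the same way; nn.Linear is only a stand-in here.
layer = nn.Linear(2, 1)
layer.register_forward_hook(lambda module, inputs, output: hook_calls.append("hook"))

x = torch.zeros(1, 2)
layer(x)          # __call__ runs forward() and then the registered forward hooks
layer.forward(x)  # calling forward() directly bypasses the hooks

print(hook_calls)  # ['hook']
```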

training: bool
class Discriminator_x_z(num_classes: int = 10)

Bases: torch.nn.modules.module.Module

The discriminator for the pair (x, z) on CIFAR-10, as described in Adversarially Learned Inference (ALI).

forward(x, z)

training: bool
class Discriminator_z

Bases: torch.nn.modules.module.Module

The discriminator for z on CIFAR-10, as described in Adversarially Learned Inference (ALI).

forward(input)

training: bool
class Generator_x

Bases: torch.nn.modules.module.Module

The generator of x on CIFAR-10, as described in Adversarially Learned Inference (ALI).

forward(input)

training: bool
class Generator_z

Bases: torch.nn.modules.module.Module

The generator of z on CIFAR-10, as described in Adversarially Learned Inference (ALI).

forward(input)

training: bool
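In ALI, a discriminator D(x, z) scores joint pairs and is trained to tell encoder pairs (x, G_z(x)) apart from decoder pairs (G_x(z), z), while the encoder and decoder are trained on the same cross-entropy with the labels swapped. The plain-Python sketch below computes both losses from discriminator probabilities. It assumes that Generator_z plays the encoder role, Generator_x the decoder role, and that the three discriminator classes are combined into one D(x, z); ali_losses is a hypothetical helper, not part of DeSSL.

```python
import math
from typing import List, Tuple


def ali_losses(rho_q: List[float], rho_p: List[float], eps: float = 1e-12) -> Tuple[float, float]:
    """Discriminator and generator losses of the ALI game (hypothetical helper).

    rho_q: D(x, G_z(x)) for encoder pairs built from real images x.
    rho_p: D(G_x(z), z) for decoder pairs built from prior samples z.
    All values are probabilities in (0, 1).
    """
    n_q, n_p = len(rho_q), len(rho_p)
    # Discriminator: encoder pairs are labelled 1, decoder pairs 0.
    loss_d = (-sum(math.log(r + eps) for r in rho_q) / n_q
              - sum(math.log(1.0 - r + eps) for r in rho_p) / n_p)
    # Encoder and decoder: the same cross-entropy with the labels swapped.
    loss_g = (-sum(math.log(1.0 - r + eps) for r in rho_q) / n_q
              - sum(math.log(r + eps) for r in rho_p) / n_p)
    return loss_d, loss_g


# At the equilibrium the discriminator outputs 0.5 for every pair.
loss_d, loss_g = ali_losses([0.5, 0.5], [0.5, 0.5])
print(round(loss_d, 4), round(loss_g, 4))  # 1.3863 1.3863
```

Both losses equal 2 ln 2 at that point, since the discriminator can no longer separate the two kinds of pairs.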

DeSSL.model.VAE module

class VAE

Bases: torch.nn.modules.module.Module

A model based on the variational autoencoder (VAE).

decode(z)
encode(x)
forward(x, y)

reparameterize(mu, logvar)
training: bool
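reparameterize(mu, logvar) presumably implements the standard VAE reparameterization trick. Assuming logvar stores log σ², it samples ε ~ N(0, 1) and returns z = μ + exp(logvar / 2) · ε, so gradients can flow through μ and logvar. The plain-Python sketch below shows that step together with the closed-form KL term against a standard normal prior that VAE losses usually include; both helpers are hypothetical.

```python
import math
import random
from typing import List


def reparameterize(mu: List[float], logvar: List[float], rng: random.Random) -> List[float]:
    # z = mu + sigma * eps, eps ~ N(0, 1); sigma = exp(logvar / 2) when logvar = log(sigma^2).
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu: List[float], logvar: List[float]) -> float:
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


rng = random.Random(0)
z = reparameterize([0.0, 1.0], [0.0, -2.0], rng)
print(kl_to_standard_normal([1.0], [0.0]))  # 0.5
```

In PyTorch the same sampling step is typically written as mu + torch.exp(0.5 * logvar) * torch.randn_like(mu).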

DeSSL.model.lenet module

class LeNet5(num_classes: int = 10)

Bases: torch.nn.modules.module.Module

The LeNet-5 network used to evaluate the algorithms on MNIST.

Parameters

num_classes – The number of categories.

forward(img)

training: bool
class LeNet5_SVHN(num_classes: int = 10)

Bases: torch.nn.modules.module.Module

The LeNet-5 network used to evaluate the algorithms on SVHN.

Parameters

num_classes – The number of categories.

forward(img)

training: bool
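Neither LeNet class documents its layer sizes. Assuming both follow the classic LeNet-5 layout (two 5×5 convolutions, each followed by 2×2 pooling, ending with 16 feature maps), the arithmetic sketch below shows why a 28×28 MNIST image padded by 2 and a 32×32 SVHN image both flatten to 400 features before the fully connected layers. The helper names are hypothetical, and DeSSL's actual layers may differ.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output side length of a square convolution or pooling layer.
    return (size + 2 * padding - kernel) // stride + 1


def lenet5_flat_features(side: int, first_padding: int = 0) -> int:
    # Classic LeNet-5 feature extractor: conv5 -> pool2 -> conv5 -> pool2, 16 maps at the end.
    side = conv_out(side, 5, padding=first_padding)  # C1
    side = conv_out(side, 2, stride=2)               # S2
    side = conv_out(side, 5)                         # C3
    side = conv_out(side, 2, stride=2)               # S4
    return 16 * side * side


print(lenet5_flat_features(32))                   # 400 (32x32 SVHN input)
print(lenet5_flat_features(28, first_padding=2))  # 400 (28x28 MNIST input padded to 32x32)
```

Without the padding, a 28×28 input would flatten to 16 × 4 × 4 = 256 features instead.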

DeSSL.model.resnet module

resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet

ResNet-101 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).

Parameters
  • pretrained – If True, returns a model pre-trained on ImageNet.

  • progress – If True, displays a progress bar of the download to stderr.

resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet

ResNet-152 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).

Parameters
  • pretrained – If True, returns a model pre-trained on ImageNet.

  • progress – If True, displays a progress bar of the download to stderr.

resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet

ResNet-18 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).

Parameters
  • pretrained – If True, returns a model pre-trained on ImageNet.

  • progress – If True, displays a progress bar of the download to stderr.

resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet

ResNet-34 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).

Parameters
  • pretrained – If True, returns a model pre-trained on ImageNet.

  • progress – If True, displays a progress bar of the download to stderr.

resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) → torchvision.models.resnet.ResNet

ResNet-50 model from “Deep Residual Learning for Image Recognition” (https://arxiv.org/pdf/1512.03385.pdf).

Parameters
  • pretrained – If True, returns a model pre-trained on ImageNet.

  • progress – If True, displays a progress bar of the download to stderr.
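The number in each factory name is the layer depth defined in Table 1 of the cited paper: ResNet-18 and ResNet-34 stack basic blocks of two 3×3 convolutions, while ResNet-50, ResNet-101 and ResNet-152 stack three-layer bottleneck blocks. The short check below recomputes each depth from the per-stage block counts given in that paper.

```python
# (blocks per stage, layers per block), following Table 1 of the ResNet paper.
RESNET_CONFIGS = {
    18: ([2, 2, 2, 2], 2),   # basic blocks: two 3x3 convolutions each
    34: ([3, 4, 6, 3], 2),
    50: ([3, 4, 6, 3], 3),   # bottleneck blocks: 1x1, 3x3, 1x1 convolutions
    101: ([3, 4, 23, 3], 3),
    152: ([3, 8, 36, 3], 3),
}


def depth(blocks, layers_per_block):
    # The initial 7x7 convolution and the final fully connected layer add 2 layers.
    return sum(blocks) * layers_per_block + 2


for name, (blocks, layers) in RESNET_CONFIGS.items():
    print(f"resnet{name}: {depth(blocks, layers)} layers")
```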

resnext101_32x8d(*args, **kwargs)
resnext50_32x4d(*args, **kwargs)
wide_resnet101_2(*args, **kwargs)
wide_resnet50_2(*args, **kwargs)
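These four functions are listed without documentation. Assuming they forward to the torchvision constructors of the same names, the suffixes set the bottleneck width: resnext50_32x4d uses groups=32 and width_per_group=4, resnext101_32x8d uses groups=32 and width_per_group=8, and the two wide_resnet variants double width_per_group from 64 to 128. The sketch reproduces the width formula of torchvision's Bottleneck block for the first stage (planes=64); bottleneck_width is a hypothetical helper.

```python
def bottleneck_width(planes: int, groups: int = 1, width_per_group: int = 64) -> int:
    # Channels of the grouped 3x3 convolution inside a torchvision Bottleneck block.
    return int(planes * (width_per_group / 64.0)) * groups


print(bottleneck_width(64))                                # 64: resnet50
print(bottleneck_width(64, groups=32, width_per_group=4))  # 128: resnext50_32x4d
print(bottleneck_width(64, groups=32, width_per_group=8))  # 256: resnext101_32x8d
print(bottleneck_width(64, width_per_group=128))           # 128: wide_resnet50_2
```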

DeSSL.model.toy module

class Ladder_MLP(input_shape: tuple, num_neurons: List[int], sigma_noise: List[float], input_sigma_noise: float, **kwargs)

Bases: torch.nn.modules.module.Module

Parameters
  • input_shape – The shape of the inputs.

  • num_neurons – The number of neurons in each layer.

  • sigma_noise – The list of standard deviations \(\sigma\) of the Gaussian noise.

  • input_sigma_noise – The standard deviation \(\sigma\) of the noise added on the supervised learning path.
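As a minimal sketch of what the noise parameters control, assume that sigma_noise holds one standard deviation per layer of the corrupted (noisy) path, as in the original Ladder network design, and that input_sigma_noise plays the same role for the input: each value receives independent zero-mean Gaussian noise with that standard deviation. The helper names below are hypothetical and not part of DeSSL.

```python
import random
from typing import List


def add_gaussian_noise(values: List[float], sigma: float, rng: random.Random) -> List[float]:
    # Corrupted path: every activation gets independent N(0, sigma^2) noise.
    return [v + rng.gauss(0.0, sigma) for v in values]


def corrupt_layers(layers: List[List[float]], sigmas: List[float],
                   rng: random.Random) -> List[List[float]]:
    # Pair each layer with its own noise level (assumed layout of sigma_noise).
    if len(layers) != len(sigmas):
        raise ValueError("expected one sigma per layer")
    return [add_gaussian_noise(layer, s, rng) for layer, s in zip(layers, sigmas)]


rng = random.Random(0)
noisy = corrupt_layers([[0.5, -1.0], [2.0]], [0.3, 0.0], rng)
print(noisy[1])  # [2.0]: a zero sigma leaves the layer unchanged
```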

clear_path(*input)
decoder_path()
forward(path_name, *input)

get_loss_d(lam_list)
noise_path(*input)
training: bool
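get_loss_d(lam_list) presumably computes the denoising cost of a Ladder network. In the original Ladder formulation, that cost is a weighted sum over layers of the mean squared error between the clean activations and their reconstructions, with one weight λ_l per layer. The sketch below assumes lam_list holds those weights in layer order and omits the batch normalization that the original Ladder network applies to the reconstructions; denoising_cost is a hypothetical name.

```python
from typing import List, Sequence

Matrix = List[List[float]]  # rows are examples, columns are units


def denoising_cost(clean: Sequence[Matrix], reconstructed: Sequence[Matrix],
                   lam_list: Sequence[float]) -> float:
    # C_d = sum_l lam_l * mean((z_l - z_hat_l)^2), the mean taken over batch and units.
    total = 0.0
    for z, z_hat, lam in zip(clean, reconstructed, lam_list):
        squared = [(a - b) ** 2
                   for row, row_hat in zip(z, z_hat)
                   for a, b in zip(row, row_hat)]
        total += lam * sum(squared) / len(squared)
    return total


clean = [[[1.0, 1.0]], [[0.0]]]
recon = [[[0.0, 1.0]], [[2.0]]]
print(denoising_cost(clean, recon, [1000.0, 1.0]))  # 504.0
```

Large weights on the lower layers, as in the example, make reconstruction of the input dominate the cost.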
class ToyNet(num_classes: int)

Bases: torch.nn.modules.module.Module

The toy neural network used to evaluate the algorithms on MNIST.

Parameters

num_classes – The number of categories.

forward(input)

training: bool
batch_normalization(batch, return_mean_and_std=False)
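batch_normalization(batch, return_mean_and_std=False) is listed without a description. Assuming it standardizes each feature with statistics computed over the batch dimension and, when return_mean_and_std is True, also returns those statistics, a plain-Python sketch might look like the following. The eps term, the biased (population) standard deviation and the return order are all assumptions, and batch_normalization_sketch is a hypothetical name.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows are examples, columns are features


def batch_normalization_sketch(batch: Matrix, return_mean_and_std: bool = False,
                               eps: float = 1e-10):
    # Per-feature statistics over the batch dimension.
    n = len(batch)
    columns = list(zip(*batch))
    mean = [sum(col) / n for col in columns]
    std = [math.sqrt(sum((x - m) ** 2 for x in col) / n) for col, m in zip(columns, mean)]
    # eps keeps constant features (std == 0) from dividing by zero.
    normalized = [[(x - m) / (s + eps) for x, m, s in zip(row, mean, std)] for row in batch]
    if return_mean_and_std:
        return normalized, mean, std
    return normalized


out, mean, std = batch_normalization_sketch([[1.0, 10.0], [3.0, 10.0]], return_mean_and_std=True)
print(mean, std)  # [2.0, 10.0] [1.0, 0.0]
```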

Module contents