
Autoencoder

The LitAutoEncoder is a PyTorch Lightning module for unsupervised learning. It wraps an encoder network and a decoder network, each supplied as an nn.Sequential module.

Key Features

  • Encoder: Compresses input data into a latent representation.
  • Decoder: Reconstructs the input data from the latent representation.
  • Loss Function: Mean Squared Error (MSE) between the reconstruction and the original (flattened) input measures reconstruction quality.
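To make the loss concrete, here is a minimal plain-Python sketch of the mean squared error between a flattened input and its reconstruction. It assumes the default mean reduction used by nn.functional.mse_loss; the helper name mse and the sample values are illustrative only.

```python
from typing import Sequence


def mse(x_hat: Sequence[float], x: Sequence[float]) -> float:
    """Mean of the squared element-wise differences (mean reduction)."""
    if len(x_hat) != len(x):
        raise ValueError("reconstruction and input must have the same length")
    return sum((a - b) ** 2 for a, b in zip(x_hat, x)) / len(x)


# A flattened input and an imperfect reconstruction of it.
x = [0.0, 0.5, 1.0, 1.0]
x_hat = [0.1, 0.5, 0.8, 1.0]
print(round(mse(x_hat, x), 6))  # 0.0125
```

A perfect reconstruction gives a loss of 0, and larger deviations are penalized quadratically.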

Autoencoder Class API

uv_datascience_project_template.lit_auto_encoder

LitAutoEncoder(encoder, decoder)

Bases: LightningModule

A simple autoencoder model.

Parameters:

  • encoder (Sequential): The encoder component, responsible for encoding input data.
  • decoder (Sequential): The decoder component, responsible for decoding encoded data.

Source code in src/uv_datascience_project_template/lit_auto_encoder.py
def __init__(self, encoder: nn.Sequential, decoder: nn.Sequential) -> None:
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
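The constructor only stores the two modules, so any pair of nn.Sequential networks with matching dimensions can be passed in. A minimal sketch, assuming 28x28 grayscale images flattened to 784 features and a 3-dimensional latent space; both sizes are illustrative and not taken from this project.

```python
import torch
from torch import nn

# Hypothetical sizes: 28x28 images flattened to 784 features, 3-D latent code.
INPUT_DIM = 28 * 28
LATENT_DIM = 3

# Encoder compresses the flattened input into the latent representation.
encoder = nn.Sequential(
    nn.Linear(INPUT_DIM, 64),
    nn.ReLU(),
    nn.Linear(64, LATENT_DIM),
)
# Decoder mirrors the encoder and maps latent codes back to input space.
decoder = nn.Sequential(
    nn.Linear(LATENT_DIM, 64),
    nn.ReLU(),
    nn.Linear(64, INPUT_DIM),
)

# Shape check on a dummy batch of 8 flattened images.
x = torch.rand(8, INPUT_DIM)
z = encoder(x)
x_hat = decoder(z)
print(z.shape, x_hat.shape)  # torch.Size([8, 3]) torch.Size([8, 784])

# The modules would then be handed to the constructor:
# from uv_datascience_project_template.lit_auto_encoder import LitAutoEncoder
# model = LitAutoEncoder(encoder, decoder)
```

The decoder's output size must equal the flattened input size, otherwise the MSE loss in training_step cannot compare x_hat with x.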

configure_optimizers()

Configures the Adam optimizer over all model parameters with a learning rate of 1e-3.

Source code in src/uv_datascience_project_template/lit_auto_encoder.py
def configure_optimizers(self) -> optim.Adam:
    """Configure the Adam optimizer."""
    optimizer = optim.Adam(self.parameters(), lr=1e-3)
    return optimizer
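To show what lr=1e-3 means in practice, the sketch below sets up Adam the same way as configure_optimizers() (all parameters, lr=1e-3), but on a small stand-in nn.Linear layer rather than the autoencoder, and takes one step. The layer, seed, and data are illustrative assumptions.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
layer = nn.Linear(4, 1)  # stand-in for the autoencoder's parameters

# Same setup as configure_optimizers(): Adam over all parameters, lr=1e-3.
optimizer = optim.Adam(layer.parameters(), lr=1e-3)

before = layer.weight.detach().clone()
loss = layer(torch.randn(16, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# On the first step Adam's bias-corrected update is lr * g / (|g| + eps),
# which is roughly lr * sign(g), so each weight moves by about 1e-3.
print((layer.weight.detach() - before).abs())
```

Because Adam normalizes by the running gradient magnitude, the learning rate roughly bounds the per-parameter step size, independent of the loss scale.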

training_step(batch, batch_idx)

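Performs a single training step: flattens each input in the batch, encodes it into a latent representation, decodes it, and returns the MSE reconstruction loss.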

Parameters:

  • batch (Tuple[Tensor, Tensor]): A tuple containing the input data (x) and the corresponding labels (y). The labels are not used by the autoencoder.
  • batch_idx (int): The index of the current batch.

Returns:

  • Tensor: The computed loss (MSE between the reconstruction and the flattened input) for the current training step.

Source code in src/uv_datascience_project_template/lit_auto_encoder.py
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
    """Performs a single training step for the model.

    Args:
        batch (Tuple[Tensor, Tensor]): A tuple containing the input data (x) and
            the corresponding labels (y).
        batch_idx (int): The index of the current batch.

    Returns:
        Tensor: The computed loss for the current training step.
    """

    x, y = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = nn.functional.mse_loss(x_hat, x)
    # Logging is disabled; uncomment to log to TensorBoard (if installed)
    # self.log("train_loss", loss)
    return loss
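The computation inside training_step can be reproduced outside Lightning to inspect shapes and the loss. This is a minimal sketch, assuming an MNIST-style batch of (N, 1, 28, 28) images with integer labels; the encoder/decoder sizes and the fake data are illustrative assumptions.

```python
import torch
from torch import nn

# Illustrative encoder/decoder for flattened 28x28 inputs (sizes are assumptions).
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# A fake batch: images of shape (N, 1, 28, 28) and integer class labels.
images = torch.rand(8, 1, 28, 28)
labels = torch.randint(0, 10, (8,))
batch = (images, labels)

# Same steps as training_step(); the labels are ignored.
x, _ = batch
x = x.view(x.size(0), -1)  # (8, 1, 28, 28) -> (8, 784)
z = encoder(x)  # latent codes, shape (8, 3)
x_hat = decoder(z)  # reconstructions, shape (8, 784)
loss = nn.functional.mse_loss(x_hat, x)

# With automatic optimization, Lightning calls backward() on the returned
# loss; it is done manually here to show that gradients reach the encoder.
loss.backward()
print(loss.item())
```

In a real run these steps are driven by a Lightning Trainer through Trainer.fit, which also calls configure_optimizers() and steps the optimizer after each batch.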