Skip to content

Autoencoder

LitAutoEncoder

Bases: LightningModule

A simple autoencoder model.

Source code in src/uv_datascience_project_template/lit_auto_encoder.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class LitAutoEncoder(L.LightningModule):
    """A simple autoencoder model."""

    def __init__(self, encoder, decoder) -> None:
        """Store the encoder and decoder submodules used by the autoencoder."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx) -> Tensor:
        """Run one optimization step and return the MSE reconstruction loss.

        Defines the training loop; it is independent of ``forward``.
        """
        inputs, _ = batch
        # Flatten each sample to a 1-D vector before encoding.
        flattened = inputs.view(inputs.size(0), -1)
        reconstruction = self.decoder(self.encoder(flattened))
        loss = nn.functional.mse_loss(reconstruction, flattened)
        # Logging to TensorBoard (if installed) by default
        # self.log("train_loss", loss)
        return loss

    def configure_optimizers(self) -> optim.Adam:
        """Create and return an Adam optimizer over all model parameters."""
        return optim.Adam(self.parameters(), lr=1e-3)

configure_optimizers()

Configure the Adam optimizer.

Source code in src/uv_datascience_project_template/lit_auto_encoder.py
29
30
31
32
def configure_optimizers(self) -> optim.Adam:
    """Create and return an Adam optimizer (lr=1e-3) over all model parameters."""
    return optim.Adam(self.parameters(), lr=1e-3)

training_step(batch, batch_idx)

Training step that defines the train loop.

`training_step` defines the training loop; it is independent of `forward`.

Source code in src/uv_datascience_project_template/lit_auto_encoder.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def training_step(self, batch, batch_idx) -> Tensor:
    """Run one optimization step and return the MSE reconstruction loss.

    Defines the training loop; it is independent of ``forward``.
    """
    inputs, _ = batch
    # Flatten each sample to a 1-D vector before encoding.
    flattened = inputs.view(inputs.size(0), -1)
    latent = self.encoder(flattened)
    reconstruction = self.decoder(latent)
    loss = nn.functional.mse_loss(reconstruction, flattened)
    # Logging to TensorBoard (if installed) by default
    # self.log("train_loss", loss)
    return loss