Training

Training the autoencoder means optimizing its parameters to minimize the reconstruction error between each input and its decoded output.
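As a minimal sketch of what "minimizing reconstruction error" means in practice (the tiny layer sizes, random data, and optimizer settings below are illustrative assumptions, not this project's actual model), one full-batch training loop looks like:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins for the real encoder/decoder; sizes are illustrative only.
encoder = nn.Sequential(nn.Linear(8, 3))
decoder = nn.Sequential(nn.Linear(3, 8))
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)

x = torch.randn(16, 8)  # fake batch of flattened inputs

def reconstruction_loss(batch: torch.Tensor) -> torch.Tensor:
    # Encode, decode, and compare the reconstruction to the original input.
    x_hat = decoder(encoder(batch))
    return nn.functional.mse_loss(x_hat, batch)

initial = reconstruction_loss(x).item()
for _ in range(50):
    optimizer.zero_grad()
    loss = reconstruction_loss(x)
    loss.backward()
    final = loss.item()
    optimizer.step()

print(initial, final)  # the reconstruction error shrinks as training proceeds
```

In the real project this loop is handled by PyTorch Lightning's `Trainer` rather than written by hand, as the source code further down shows.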

Steps

  1. Data Preparation: Load and preprocess the dataset.
  2. Model Initialization: Instantiate the LitAutoEncoder model.
  3. Training Loop: Use PyTorch Lightning's Trainer to handle the training process.

Example

from uv_datascience_project_template.train_autoencoder import train_litautoencoder

# Train the autoencoder on MNIST; returns the trained encoder, decoder,
# and a flag indicating that training completed.
encoder, decoder, is_model_trained = train_litautoencoder()

Training API

uv_datascience_project_template.train_autoencoder

train_litautoencoder()

Trains a LitAutoEncoder model on the MNIST dataset and returns the trained encoder, decoder, and a flag indicating training completion.

RETURNS

tuple[Sequential, Sequential, Literal[True]]: A tuple containing the trained encoder, decoder, and a flag (always True) indicating that the model has been successfully trained.

Source code in src/uv_datascience_project_template/train_autoencoder.py
def train_litautoencoder() -> tuple[Sequential, Sequential, Literal[True]]:
    """Trains a LitAutoEncoder model on the MNIST dataset and returns
    the trained encoder, decoder, and a flag indicating training completion.

    Returns:
        tuple[Sequential, Sequential, Literal[True]]: A tuple containing the trained encoder,
            decoder, and a boolean flag indicating that the model has been successfully trained.
    """  # noqa: D205

    # Define the encoder and decoder architecture
    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    # Initialize the LitAutoEncoder
    autoencoder = LitAutoEncoder(encoder, decoder)

    # Load the MNIST dataset
    dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
    train_loader = utils.data.DataLoader(dataset)

    # Initialize the CSV logger
    logger = CSVLogger("lightning_logs", name="LitAutoEncoder")

    # Train the autoencoder
    trainer = L.Trainer(limit_train_batches=100, max_epochs=1, logger=logger)
    trainer.fit(model=autoencoder, train_dataloaders=train_loader)

    is_model_trained = True  # Mark model as trained

    return encoder, decoder, is_model_trained
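After training, the returned encoder maps a flattened 28x28 image to a 3-dimensional embedding, and the decoder maps that embedding back to a 784-dimensional reconstruction. The sketch below illustrates the tensor shapes involved; it uses freshly initialized modules with the same architecture rather than actually running `train_litautoencoder()`, so the weights here are untrained:

```python
import torch
from torch import nn

# Same shapes as in train_litautoencoder(); weights are untrained and
# serve only to illustrate the shapes of the encoder/decoder outputs.
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

image = torch.rand(1, 28 * 28)  # one flattened MNIST-sized image
with torch.no_grad():
    embedding = encoder(image)           # shape (1, 3)
    reconstruction = decoder(embedding)  # shape (1, 784)

print(embedding.shape, reconstruction.shape)
```

With the actual trained modules returned by `train_litautoencoder()`, the same two calls produce a compressed representation and its reconstruction of a real MNIST digit.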