Training

Training optimizes the autoencoder to minimize reconstruction error on the input data: the encoder compresses each input into a low-dimensional latent vector, and the decoder attempts to reproduce the original input from it.
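
Concretely, reconstruction error is typically measured as the mean squared error between the input and its reconstruction. The sketch below illustrates the idea; the exact loss implemented inside LitAutoEncoder may differ.

import torch.nn.functional as F

def reconstruction_loss(encoder, decoder, x):
    x = x.view(x.size(0), -1)    # flatten images, e.g. (N, 1, 28, 28) -> (N, 784)
    z = encoder(x)               # compress into the latent space
    x_hat = decoder(z)           # reconstruct the input from the latent vector
    return F.mse_loss(x_hat, x)  # mean squared reconstruction error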

Steps

  1. Data Preparation: Load and preprocess the MNIST dataset.
  2. Model Initialization: Instantiate the LitAutoEncoder model with encoder and decoder.
  3. Training Loop: Use PyTorch Lightning's Trainer to handle the training process, as sketched below.
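
The sketch below walks through these three steps end to end using the standard PyTorch Lightning workflow. The dataset path and hyperparameters are illustrative placeholders, and the LitAutoEncoder import is assumed to come from this project (see the source listing below).

import lightning as L
from torch import nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# 1. Data Preparation: download MNIST and wrap it in a DataLoader
dataset = MNIST("./data", download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# 2. Model Initialization: LitAutoEncoder wraps an encoder/decoder pair
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
autoencoder = LitAutoEncoder(encoder, decoder, learning_rate=1e-3)

# 3. Training Loop: the Trainer runs the optimization
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)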

Example

from uv_datascience_project_template.train_autoencoder import train_litautoencoder

# Train the autoencoder; `settings` is a Settings instance holding the
# model, training, and data configuration
encoder, decoder, is_trained, checkpoint_path = train_litautoencoder(settings)

Training API

uv_datascience_project_template.train_autoencoder

train_litautoencoder(settings)

Trains a LitAutoEncoder model on the MNIST dataset and returns the trained encoder, decoder, a flag indicating training completion, and the checkpoint path.

Parameters:

    settings (Settings): The settings object containing model, training, and data configurations.

Returns:

    tuple[Sequential, Sequential, Literal[True], str]: A tuple containing the trained encoder, decoder, a boolean flag indicating that the model has been successfully trained, and the path to the saved model checkpoint.
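
For example, the returned encoder and decoder can be used directly for inference. The sketch below assumes the default MNIST dimensions, i.e. settings.model.input_dim == 28 * 28.

import torch

encoder, decoder, is_trained, checkpoint_path = train_litautoencoder(settings)

# Encode a small batch of flattened inputs and reconstruct them
x = torch.rand(4, 28 * 28)
with torch.no_grad():
    z = encoder(x)      # shape: (4, latent_dim)
    x_hat = decoder(z)  # shape: (4, 28 * 28)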

Source code in src/uv_datascience_project_template/train_autoencoder.py
def train_litautoencoder(settings: Settings) -> tuple[Sequential, Sequential, Literal[True], str]:
    """Trains a LitAutoEncoder model on the MNIST dataset and returns
    the trained encoder, decoder, a flag indicating training completion, and the checkpoint path.

    Args:
        settings (Settings): The settings object containing model, training,
            and data configurations.

    Returns:
        tuple[Sequential, Sequential, Literal[True], str]: A tuple containing the trained encoder,
            decoder, a boolean flag indicating that the model has been successfully trained,
            and the path to the saved model checkpoint.
    """  # noqa: D205

    # Define the encoder and decoder architecture
    # encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    encoder = nn.Sequential(
        nn.Linear(settings.model.input_dim, settings.model.hidden_dim),
        nn.ReLU(),
        nn.Linear(settings.model.hidden_dim, settings.model.latent_dim),
    )
    # decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    decoder = nn.Sequential(
        nn.Linear(settings.model.latent_dim, settings.model.hidden_dim),
        nn.ReLU(),
        nn.Linear(settings.model.hidden_dim, settings.model.input_dim),
    )

    # Initialize the LitAutoEncoder
    autoencoder = LitAutoEncoder(encoder, decoder, learning_rate=settings.training.learning_rate)

    # Load the MNIST dataset
    # Note: Ensure the path is set correctly in settings.data.mnist_data_path
    # dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
    dataset = MNIST(settings.data.mnist_data_path, download=True, transform=ToTensor())
    train_loader = utils.data.DataLoader(dataset)

    # Initialize the CSV logger for training metrics
    logger = CSVLogger("lightning_logs", name="LitAutoEncoder")

    # Define ModelCheckpoint callback
    checkpoint_dir = os.path.join(
        settings.data.mnist_data_path, "lightning_logs", "LitAutoEncoder", "checkpoints"
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir, filename="autoencoder-{epoch:02d}", save_last=True
    )

    # Train the autoencoder
    # trainer = L.Trainer(limit_train_batches=100, max_epochs=1, logger=logger)
    trainer = L.Trainer(
        limit_train_batches=settings.training.limit_train_batches,
        max_epochs=settings.training.epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model=autoencoder, train_dataloaders=train_loader)

    # Save a checkpoint explicitly at the end of training
    final_checkpoint_path = os.path.join(
        str(settings.data.mnist_data_path),
        "lightning_logs",
        "LitAutoEncoder",
        "checkpoints",
        "final_autoencoder.ckpt",
    )
    os.makedirs(os.path.dirname(final_checkpoint_path), exist_ok=True)
    trainer.save_checkpoint(final_checkpoint_path)

    is_model_trained = True  # Mark model as trained

    # Return the trained encoder/decoder, the trained flag, and the final checkpoint path
    return autoencoder.encoder, autoencoder.decoder, is_model_trained, final_checkpoint_path
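
The returned checkpoint path can later be restored with Lightning's standard load_from_checkpoint. A minimal sketch, re-passing the encoder and decoder because they are constructor arguments of LitAutoEncoder:

# Restore the trained model, e.g. in a separate serving process
autoencoder = LitAutoEncoder.load_from_checkpoint(
    final_checkpoint_path, encoder=encoder, decoder=decoder
)
autoencoder.eval()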