Trains a LitAutoEncoder model on the MNIST dataset and returns
the trained encoder, decoder, and a flag indicating training completion.
Returns:

| Type | Description |
| --- | --- |
| `tuple[Sequential, Sequential, Literal[True]]` | A tuple containing the trained encoder, decoder, and a boolean flag indicating that the model has been successfully trained. |
Source code in packages/lit-auto-encoder/src/lit_auto_encoder/train_autoencoder.py
```python
def train_litautoencoder() -> tuple[Sequential, Sequential, Literal[True]]:
"""Trains a LitAutoEncoder model on the MNIST dataset and returns
the trained encoder, decoder, and a flag indicating training completion.
Returns:
tuple[Sequential, Sequential, Literal[True]]: A tuple containing the trained encoder,
decoder, and a boolean flag indicating that the model has been successfully trained.
""" # noqa: D205
# Define the encoder and decoder architecture
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
# Initialize the LitAutoEncoder
autoencoder = LitAutoEncoder(encoder, decoder)
# Load the MNIST dataset
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)
# Initialize the TensorBoard logger
    # Initialize the CSV logger to record training metrics
    logger = CSVLogger("lightning_logs", name="LitAutoEncoder")
# Train the autoencoder
trainer = L.Trainer(limit_train_batches=100, max_epochs=1, logger=logger)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
is_model_trained = True # Mark model as trained
return encoder, decoder, is_model_trained
```
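A minimal usage sketch is shown below. It unpacks the returned tuple and runs a single MNIST digit through the trained encoder and decoder. The import path `lit_auto_encoder.train_autoencoder` is an assumption inferred from the source location above, not confirmed by this page.

```python
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Assumed import path, based on
# packages/lit-auto-encoder/src/lit_auto_encoder/train_autoencoder.py
from lit_auto_encoder.train_autoencoder import train_litautoencoder

# Train the autoencoder and unpack the returned tuple.
encoder, decoder, is_trained = train_litautoencoder()
assert is_trained  # the third element is always True once training completes

# Encode and reconstruct a single MNIST digit with the trained modules.
dataset = MNIST(".", download=True, transform=ToTensor())
image, _label = dataset[0]            # image shape: (1, 28, 28)
flat = image.view(1, 28 * 28)         # flatten to match the encoder input

with torch.no_grad():
    embedding = encoder(flat)         # 3-dimensional latent vector
    reconstruction = decoder(embedding)

print(embedding.shape, reconstruction.shape)  # torch.Size([1, 3]) torch.Size([1, 784])
```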