Training function to initialize models and perform training.
Source code in src/uv_datascience_project_template/train_autoencoder.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def train_litautoencoder() -> tuple[Sequential, Sequential, Literal[True]]:
    """Build an encoder/decoder pair, wrap them in a LitAutoEncoder and train it.

    The encoder maps a flattened 28 * 28 input to 3 features through a
    64-unit hidden layer; the decoder mirrors it back to 28 * 28 outputs.
    Training uses the MNIST dataset, which is downloaded into the current
    working directory, and is capped at 100 batches for a single epoch.
    Metrics are written by a CSV logger under ``lightning_logs/LitAutoEncoder``.

    Returns:
        A tuple ``(encoder, decoder, is_model_trained)``. ``is_model_trained``
        is always ``True`` when the function returns normally. The returned
        encoder and decoder are the same module objects that were handed to
        ``LitAutoEncoder`` (presumably updated in place by training; confirm
        against the ``LitAutoEncoder`` implementation).
    """
    # Define the encoder and decoder architecture
    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    # Initialize the LitAutoEncoder
    autoencoder = LitAutoEncoder(encoder, decoder)

    # Load the MNIST dataset (downloaded into the current working directory)
    dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
    # DataLoader is created with its default settings
    train_loader = utils.data.DataLoader(dataset)

    # Initialize the CSV logger (not TensorBoard): metrics go to
    # lightning_logs/LitAutoEncoder
    logger = CSVLogger("lightning_logs", name="LitAutoEncoder")

    # Train the autoencoder: one epoch, limited to 100 training batches
    trainer = L.Trainer(limit_train_batches=100, max_epochs=1, logger=logger)
    trainer.fit(model=autoencoder, train_dataloaders=train_loader)

    is_model_trained = True  # Mark model as trained
    return encoder, decoder, is_model_trained
|