Jupyter Notebook - App
Run the application interactively without using FastAPI
In [ ]:
Copied!
import warnings
from typing import List, Tuple
from pydantic import BaseModel, Field
from torch import rand
from torch.nn import Sequential
from typing_extensions import Annotated
from lit_auto_encoder.auto_encoder import LitAutoEncoder
from lit_auto_encoder.train_autoencoder import train_litautoencoder
warnings.filterwarnings("ignore")
import warnings
from typing import List, Tuple
from pydantic import BaseModel, Field
from torch import rand
from torch.nn import Sequential
from typing_extensions import Annotated
from lit_auto_encoder.auto_encoder import LitAutoEncoder
from lit_auto_encoder.train_autoencoder import train_litautoencoder
warnings.filterwarnings("ignore")
In [ ]:
Copied!
class NumberFakeImages(BaseModel):
    """Input validation model: how many fake images to generate and embed.

    ``n_fake_images`` is required and must lie in the inclusive range 1..10.
    """

    # Required field (``...``); pydantic enforces 1 <= value <= 10.
    n_fake_images: int = Field(..., ge=1, le=10)
class NumberFakeImages(BaseModel):
    """Input validation model: how many fake images to generate and embed.

    ``n_fake_images`` is required and must lie in the inclusive range 1..10.
    """

    # Required field (``...``); pydantic enforces 1 <= value <= 10.
    n_fake_images: int = Field(..., ge=1, le=10)
Train the Model
In [ ]:
Copied!
def train_model() -> Tuple[Sequential, Sequential]:
    """Fit the autoencoder and hand back its two sub-networks.

    ``train_litautoencoder`` returns three values. Only the first two (the
    encoder and decoder) are kept; the third, presumably an "is trained"
    flag, is discarded.

    Returns:
        Tuple[Sequential, Sequential]: The ``(encoder, decoder)`` pair.
    """
    trained_encoder, trained_decoder, _ = train_litautoencoder()
    return trained_encoder, trained_decoder
def train_model() -> Tuple[Sequential, Sequential]:
    """Fit the autoencoder and hand back its two sub-networks.

    ``train_litautoencoder`` returns three values. Only the first two (the
    encoder and decoder) are kept; the third, presumably an "is trained"
    flag, is discarded.

    Returns:
        Tuple[Sequential, Sequential]: The ``(encoder, decoder)`` pair.
    """
    trained_encoder, trained_decoder, _ = train_litautoencoder()
    return trained_encoder, trained_decoder
In [ ]:
Copied!
# Fit the autoencoder once and keep both halves for the embedding step.
(encoder, decoder) = train_model()
# Fit the autoencoder once and keep both halves for the embedding step.
(encoder, decoder) = train_model()
GPU available: True (mps), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs | Name | Type | Params | Mode ----------------------------------------------- 0 | encoder | Sequential | 50.4 K | train 1 | decoder | Sequential | 51.2 K | train ----------------------------------------------- 101 K Trainable params 0 Non-trainable params 101 K Total params 0.407 Total estimated model params size (MB) 8 Modules in train mode 0 Modules in eval mode
Epoch 0: 100%|██████████| 100/100 [00:00<00:00, 203.43it/s, v_num=80]
`Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 0: 100%|██████████| 100/100 [00:00<00:00, 198.91it/s, v_num=80]
Create Embeddings
In [ ]:
Copied!
def create_embed(
    input_data: NumberFakeImages, encoder: Sequential, decoder: Sequential, checkpoint_path: str
) -> List[List[float]]:
    """Embed randomly generated fake images using a checkpointed autoencoder.

    Args:
        input_data (NumberFakeImages): Validated input holding the number of
            fake images to generate and embed.
        encoder (Sequential): Encoder module forwarded to
            ``LitAutoEncoder.load_from_checkpoint``.
        decoder (Sequential): Decoder module forwarded to
            ``LitAutoEncoder.load_from_checkpoint``; it is not used to
            compute the embeddings.
        checkpoint_path (str): Path to the checkpoint file.

    Returns:
        List[List[float]]: The encoder output converted to nested Python
        lists, one inner list per fake image.
    """
    # Function-scope import: the notebook only imports ``rand`` from torch.
    from torch import no_grad

    n_fake_images = input_data.n_fake_images

    # Load the trained autoencoder from the checkpoint.
    autoencoder = LitAutoEncoder.load_from_checkpoint(
        checkpoint_path, encoder=encoder, decoder=decoder
    )
    encoder_model = autoencoder.encoder
    encoder_model.eval()

    # Random inputs with 28 * 28 features each, created on the model's device.
    # Assumes the encoder expects flattened 28x28 images; TODO confirm.
    fake_image_batch = rand(n_fake_images, 28 * 28, device=autoencoder.device)

    # Pure inference: disable autograd so no computation graph is recorded.
    with no_grad():
        embeddings = encoder_model(fake_image_batch)
    return embeddings.tolist()
def create_embed(
    input_data: NumberFakeImages, encoder: Sequential, decoder: Sequential, checkpoint_path: str
) -> List[List[float]]:
    """Embed randomly generated fake images using a checkpointed autoencoder.

    Args:
        input_data (NumberFakeImages): Validated input holding the number of
            fake images to generate and embed.
        encoder (Sequential): Encoder module forwarded to
            ``LitAutoEncoder.load_from_checkpoint``.
        decoder (Sequential): Decoder module forwarded to
            ``LitAutoEncoder.load_from_checkpoint``; it is not used to
            compute the embeddings.
        checkpoint_path (str): Path to the checkpoint file.

    Returns:
        List[List[float]]: The encoder output converted to nested Python
        lists, one inner list per fake image.
    """
    # Function-scope import: the notebook only imports ``rand`` from torch.
    from torch import no_grad

    n_fake_images = input_data.n_fake_images

    # Load the trained autoencoder from the checkpoint.
    autoencoder = LitAutoEncoder.load_from_checkpoint(
        checkpoint_path, encoder=encoder, decoder=decoder
    )
    encoder_model = autoencoder.encoder
    encoder_model.eval()

    # Random inputs with 28 * 28 features each, created on the model's device.
    # Assumes the encoder expects flattened 28x28 images; TODO confirm.
    fake_image_batch = rand(n_fake_images, 28 * 28, device=autoencoder.device)

    # Pure inference: disable autograd so no computation graph is recorded.
    with no_grad():
        embeddings = encoder_model(fake_image_batch)
    return embeddings.tolist()
In [ ]:
Copied!
# Embed two random fake images with the checkpointed encoder.
# NOTE(review): the checkpoint path is pinned to ``version_0``, while the
# training progress bar reports ``v_num=80``. Confirm this loads the
# checkpoint produced by the run that just finished.
embeddings = create_embed(
    NumberFakeImages(n_fake_images=2),
    encoder,
    decoder,
    checkpoint_path="./lightning_logs/LitAutoEncoder/version_0/checkpoints/epoch=0-step=100.ckpt",
)

# Show the embeddings framed by a banner of 20 lightning bolts.
print(f"{'⚡' * 20}\nPredictions (image embeddings):\n{embeddings}\n{'⚡' * 20}")
# Embed two random fake images with the checkpointed encoder.
# NOTE(review): the checkpoint path is pinned to ``version_0``, while the
# training progress bar reports ``v_num=80``. Confirm this loads the
# checkpoint produced by the run that just finished.
embeddings = create_embed(
    NumberFakeImages(n_fake_images=2),
    encoder,
    decoder,
    checkpoint_path="./lightning_logs/LitAutoEncoder/version_0/checkpoints/epoch=0-step=100.ckpt",
)

# Show the embeddings framed by a banner of 20 lightning bolts.
print(f"{'⚡' * 20}\nPredictions (image embeddings):\n{embeddings}\n{'⚡' * 20}")
⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ Predictions (image embeddings): [[0.3581562042236328, 0.14595450460910797, 0.47827404737472534], [0.5079143643379211, 0.09660176187753677, 0.5766311287879944]] ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
In [ ]:
Copied!