Skip to content

FastAPI App

embed(input_data)

Embed fake images using the trained autoencoder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_data` | `NumberFakeImages` | Input data containing the number of fake images to embed. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Any]` | A dictionary containing the embeddings of the fake images. |

Source code in `src/uv_datascience_project_template/app_fastapi_autoencoder.py` (lines 57–94):
@app.post("/embed")
def embed(input_data: NumberFakeImages) -> dict[str, Any]:
    """Embed fake images using the trained autoencoder.

    Reloads the autoencoder from its Lightning checkpoint and builds a random
    batch of shape ``(n_fake_images, 28 * 28)``. It runs that batch through
    the encoder with gradient tracking disabled and returns the result.

    Args:
        input_data (NumberFakeImages): Input data containing the number of fake images to embed.

    Returns:
        dict[str, Any]: A dictionary containing the embeddings of the fake images
            under the ``"embeddings"`` key, converted to nested Python lists.

    Raises:
        HTTPException: With status 500 if the encoder/decoder have not been
            created yet, or if the checkpoint file does not exist.
    """
    global encoder, decoder

    # Local import: only this endpoint needs it. It lets inference run without
    # recording an autograd graph for every request.
    from torch import no_grad

    if encoder is None or decoder is None:
        raise HTTPException(
            status_code=500, detail="Model not initialized. Train the model first."
        )

    n_fake_images = input_data.n_fake_images
    # NOTE(review): hard-coded to a single run/epoch/step. This assumes training
    # always writes exactly this checkpoint. Confirm against the training code.
    checkpoint_path = "./lightning_logs/LitAutoEncoder/version_0/checkpoints/epoch=0-step=100.ckpt"

    if not os.path.exists(checkpoint_path):
        raise HTTPException(
            status_code=500, detail="Checkpoint file not found. Train the model first."
        )

    # Load the trained autoencoder from the checkpoint
    autoencoder = LitAutoEncoder.load_from_checkpoint(
        checkpoint_path, encoder=encoder, decoder=decoder
    )
    encoder_model = autoencoder.encoder
    encoder_model.eval()

    # Pure inference: skip gradient bookkeeping to save memory and time.
    with no_grad():
        # One random flattened 28x28 "image" per requested embedding.
        fake_image_batch = rand(n_fake_images, 28 * 28, device=autoencoder.device)
        embeddings = encoder_model(fake_image_batch)

    return {"embeddings": embeddings.tolist()}

read_root()

Root endpoint that provides information about the API.

Source code in `src/uv_datascience_project_template/app_fastapi_autoencoder.py` (lines 26–36):
@app.get("/")
def read_root() -> Response:
    """Serve a plain-text banner that explains how to use the API endpoints."""
    usage_text = """
    ⚡⚡⚡ Welcome to the LitAutoEncoder API! ⚡⚡⚡
    - To train the model, send a POST request to '/train' without providing any additional input.
    - To get encodings for random fake images, POST to '/embed' with JSON input: 
      {'n_fake_images': [1-10]} in the request body.
    """
    # Returned as text/plain so the banner is shown verbatim, not as JSON.
    banner_response = Response(content=usage_text, media_type="text/plain")
    return banner_response

train_model()

Train the autoencoder model.

Returns:

| Type | Description |
| --- | --- |
| `dict[str, str]` | A message indicating the training status. |

Source code in `src/uv_datascience_project_template/app_fastapi_autoencoder.py` (lines 40–53):
@app.post("/train")
def train_model() -> dict[str, str]:
    """Train the autoencoder unless a previous call already did so.

    Training is skipped whenever the module-level ``is_model_trained`` flag is
    truthy. Otherwise the module-level ``encoder``, ``decoder`` and
    ``is_model_trained`` are replaced with the values returned by
    ``train_litautoencoder()``.

    Returns:
        dict[str, str]: A message indicating the training status.
    """
    global encoder, decoder, is_model_trained

    if not is_model_trained:
        encoder, decoder, is_model_trained = train_litautoencoder()
        status_text = "Model training completed successfully."
    else:
        status_text = "Model is already trained."
    return {"message": status_text}