
FastAPI App

uv_datascience_project_template.app_fastapi_autoencoder
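
The endpoint documentation below assumes the app is being served locally. A minimal sketch of one way to launch it, assuming uvicorn is installed and the import string matches the source file shown below (the helper script name is hypothetical):

# launch_api.py, a hypothetical launcher script (not part of the template)
import uvicorn

# Serve the FastAPI instance defined in app_fastapi_autoencoder.py on localhost:8000
uvicorn.run(
    "uv_datascience_project_template.app_fastapi_autoencoder:app",
    host="127.0.0.1",
    port=8000,
)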

embed(input_data)

Embed fake images using the trained autoencoder.

PARAMETER DESCRIPTION
input_data (NumberFakeImages): Input data containing the number of fake images to embed.

RETURNS DESCRIPTION
dict[str, Any]: A dictionary containing the embeddings of the fake images.

Source code in src/uv_datascience_project_template/app_fastapi_autoencoder.py
@app.post("/embed")
def embed(input_data: NumberFakeImages) -> dict[str, Any]:
    """Embed fake images using the trained autoencoder.

    Args:
        input_data (NumberFakeImages): Input data containing the number of fake images to embed.

    Returns:
        dict[str, Any]: A dictionary containing the embeddings of the fake images.
    """
    if app.state.encoder is None or app.state.decoder is None:
        raise HTTPException(
            status_code=500, detail="Model not initialized. Train the model first."
        )

    n_fake_images = input_data.n_fake_images

    if not app.state.checkpoint_path or not os.path.exists(app.state.checkpoint_path):
        raise HTTPException(
            status_code=500, detail="Checkpoint file not found. Train the model first."
        )

    # Load the trained autoencoder from the checkpoint
    autoencoder = LitAutoEncoder.load_from_checkpoint(app.state.checkpoint_path)
    encoder_model = autoencoder.encoder
    encoder_model.eval()

    # Generate fake image embeddings based on user input
    # fake_image_batch = rand(n_fake_images, 28 * 28, device=autoencoder.device)
    fake_image_batch = rand(n_fake_images, settings.model.input_dim, device=autoencoder.device)
    embeddings = encoder_model(fake_image_batch)
    # print("⚡" * 20, "\nPredictions (image embeddings):\n", embeddings, "\n", "⚡" * 20)

    return {"embeddings": embeddings.tolist()}

read_root()

Root endpoint that provides information about the API.

Source code in src/uv_datascience_project_template/app_fastapi_autoencoder.py
@app.get("/")
def read_root() -> Response:
    """Root endpoint that provides information about the API."""

    message = """
    ⚡⚡⚡ Welcome to the LitAutoEncoder API! ⚡⚡⚡
    - To train the model, send a POST request to '/train' without providing any additional input.
    - To get encodings for random fake images, POST to '/embed' with JSON input:
      {'n_fake_images': [1-10]} in the request body.
    """
    return Response(content=message, media_type="text/plain")
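
The root endpoint can be queried the same way; this sketch assumes the local server address used above:

import requests

# GET / returns a plain-text welcome message describing the available endpoints.
print(requests.get("http://127.0.0.1:8000/").text)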

startup_event() async

Initialize FastAPI app state on startup.

Source code in src/uv_datascience_project_template/app_fastapi_autoencoder.py
@app.on_event("startup")
async def startup_event():
    """Initialize FastAPI app state on startup."""
    app.state.encoder = None
    app.state.decoder = None
    app.state.is_model_trained = False
    app.state.checkpoint_path = None
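
Note that @app.on_event("startup") is deprecated in recent FastAPI releases in favor of a lifespan context manager. A sketch of equivalent state initialization using that API (an alternative, not what the module currently does):

from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Same state initialization as startup_event, run before the app starts serving.
    app.state.encoder = None
    app.state.decoder = None
    app.state.is_model_trained = False
    app.state.checkpoint_path = None
    yield  # the app serves requests here; teardown code would go after the yield


app = FastAPI(lifespan=lifespan)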

train_model()

Train the autoencoder model.

RETURNS DESCRIPTION
dict[str, str]: A message indicating the training status.

Source code in src/uv_datascience_project_template/app_fastapi_autoencoder.py
@app.post("/train")
def train_model() -> dict[str, str]:
    """Train the autoencoder model.

    Returns:
        dict[str, str]: A message indicating the training status.
    """
    if app.state.is_model_trained:
        return {"message": "Model is already trained."}

    encoder, decoder, is_model_trained, checkpoint_path = train_litautoencoder(settings)
    app.state.encoder = encoder
    app.state.decoder = decoder
    app.state.is_model_trained = is_model_trained
    app.state.checkpoint_path = checkpoint_path
    return {"message": "Model training completed successfully."}