Cree su aplicación web para identificar dígitos usando Python, Tensorflow, Lime, Umap, Sklearn y Streamlit – Parte 2

Índice

Introducción

Esta sección es la segunda parte de:

Cree su aplicación web para identificar dígitos usando Python, Tensorflow, Lime, Umap, Sklearn y Streamlit

En la anterior publicación entrenamos nuestros modelos de Machine Learning que nos ayudan no solo a identificar los dígitos escritos a mano, también nos ayuda a poder entender las reglas que ha tenido en cuenta para predecir si un dígito es, por ejemplo, 0 o 1, …

En esta publicación crearemos nuestra aplicación web usando StreamLit.

Preparando las funciones

En esta sección vamos a preparar nuestras funciones para nuestra Aplicación Web:

Importando librerías

Empezamos importando las librerías necesarias:

  • pandas: para leer nuestra tabla de datos.
  • numpy: para realizar operaciones con matrices.
  • cv2: para tratamiento de imágenes.
  • Tensorflow: para la predicción de nuestro modelo Deep Learning.
  • stramlit: para crear nuestra aplicación web.
  • streamlit_drawable_canvas: para poder graficar en nuestra aplicación web.
  • sklearn: para usar los modelos de NearestNeighbors.
  • skimage: para añadir opciones de gráficos.
  • seaborn y matplotlib: para gráficos.
  • pickle y dill: para leer nuestros modelos guardados.
  • config: en este fichero se encuentra la ruta donde hemos guardado nuestros modelos.
import pandas as pd
import numpy as np
import cv2
import tensorflow as tf
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from sklearn.model_selection import train_test_split
from skimage.segmentation import mark_boundaries
from skimage.color import label2rgb
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import dill

import config


st.set_option('deprecation.showPyplotGlobalUse', False)

Pasos previos

Antes de usar nuestras funciones necesitamos definir algunos objetos. Estos objetos almacenarán nuestros modelos de Machine Learning e información importante, como DataFrame para graficar los embeddings, y nuestras matrices para graficar los dígitos:

model = None
emb_model = None
X_train = None
y_train = None
nearest_model = None
explainer = None
segmenter = None
umap_model = None
umap_train_df = None

Funciones para tratamiento de imágenes

Las siguientes funciones realizan lo siguiente:

  • to_rgb: nuestros dígitos solo tienen un canal, este canal es por la escala de grises. Para eso añadimos dos canales más, para que nuestros dígitos tengan 3 canales: RGB y lo podamos graficar y usar las opciones de LIME.
  • prepara_img: recibe una imagen que tiene dimensión 192 x 192 y lo transforma a una dimensiones de 28×28, luego añade los 2 canales faltantes haciendo uso de la función to_rgb.
# PREPARANDO IMAGEN

def to_rgb(x):
    ''' Convertimos una imagen de escala de grises a RGB'''
    x_rgb = np.zeros((x.shape[0], 28, 28, 3))
    for i in range(3):
        x_rgb[..., i] = x[..., 0]
    return x_rgb.reshape(-1, 28, 28, 3)


def prepara_img(image_array):
    ''' preparamos una imagen para predecir el dígito '''
    img = cv2.resize(image_array, (28, 28))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = img.reshape(-1, 28, 28, 1)
    return to_rgb(img)

Predicción de dígitos

Empezamos con las predicciones de nuestros modelo Deep Learning. Nuestro modelo ha sido entrenado en nuestra publicación anterior.

Nuestro modelo recibe como entrada una matriz de tamaño 28 x 28 x 3 y devuelve la probabilidad de pertenecer a cada digito. La siguiente imagen describe el proceso de predicción:

Luego las probabilidades se van a graficar de la siguiente manera:

# CARGANDO Y PREDICIENDO LOS DIGITOS

def load_model():
    ''' Cargando el modelo RED NEURONAL '''
    global model
    global emb_model
    if model is None or emb_model is None:
        model_file = open(config.MODEL_PATH_JSON, 'r')
        model = model_file.read()
        model_file.close()
        model = tf.keras.models.model_from_json(model)
        model.load_weights(config.MODEL_PATH_H5)

        emb_model = tf.keras.models.Model(model.input,
                                          model.get_layer('embedding').output)
    return model, emb_model


def predict_class(img):
    ''' Calculamos y graficamos las predicciones '''
    global model
    global emb_model
    model, emb_model = load_model()
    predictions = model.predict(img)
    predictions = predictions.ravel()
    clase_predicha = int(predictions.argmax())
    prob_ = 100 * predictions[clase_predicha]
    pred_df = pd.DataFrame(predictions, index=range(10))
    st.subheader(f'Clase predicha: {clase_predicha}, probabilidad: {prob_:.2f}')
    st.bar_chart(pred_df)

Lime para entender las predicciones

Si bien nuestro modelo nos predice con poco error el dígito que hemos escrito, es interesante saber qué pixeles ha tenido en cuenta para hacer la predicción.

Un modelo Deep Learning es una caja negra, es decir, aunque tenga buena precisión, es dificil entender lo que aprendió.

Para poder entender las reglas de nuestro modelo, haremos uso de la librería Lime. Esta librería trata de explicar qué están haciendo  nuestros modelos de Machine Learning.

La siguiente imagen explica el funcionamiento de esta librería. Para más información puedes ir al siguiente enlace:

Tenemos que tener en cuenta que la explicación es para una imagen en particular. Esto viene bien cuando queremos analizar los dígitos mal clasificados y tratar de entender por qué han sido mal clasificados.

La salida del algoritmo LIME se van a graficar de la siguiente manera.

En el primer gráfico podemos ver el dígito escrito y distintas reglas para cada clase, desde 0 a 9. Las reglas se pueden ver de color rojizo.

En el segundo gráfico podemos ver el dígito escrito y las reglas, de color amarillo, que ha tenido en cuenta nuestro modelo para la clase predicha.

# DECISIONES USANDO LIME

def load_explainer():
    ''' Cargando el modelo LIME '''
    global explainer
    global segmenter
    if explainer is None or segmenter is None:
        explainer = dill.load(open(config.MODEL_EXPLAINER, 'rb'))
        segmenter = dill.load(open(config.MODEL_SEGMENTER, 'rb'))
    return explainer, segmenter


def plot_rules(img):
    ''' Obtenemos las reglas de decisión usando LIME '''
    global model
    global emb_model
    global explainer
    global segmenter
    model, emb_model = load_model()
    explainer, segmenter = load_explainer()

    X_eval = img.reshape(28, 28, 3)

    explanation = explainer.explain_instance(X_eval,
                                             classifier_fn=model.predict,
                                             top_labels=10,
                                             hide_color=0,
                                             num_samples=100,
                                             segmentation_fn=segmenter)

    plt.figure(figsize=(15, 10))

    for i in range(10):
        temp, mask = explanation.get_image_and_mask(i,
                                                    positive_only=True,
                                                    num_features=1000,
                                                    hide_rest=False,
                                                    min_weight=0.01)
        plt.subplot(2, 5, (i + 1))
        plt.imshow(label2rgb(mask.astype(np.uint8),
                             X_eval.astype(np.uint8),
                             bg_label=0),
                   interpolation='nearest')
        plt.title(f'Positivo para clase: {i}')
        plt.axis('off')
    plt.axis('off')

    st.pyplot()

    clase_predicha = model.predict(X_eval.reshape((1, 28, 28, 3))).argmax(axis=1)[0]
    image, mask = explanation.get_image_and_mask(clase_predicha,
                                                 positive_only=True,
                                                 hide_rest=False)
    plt.imshow(X_eval.astype(np.uint8))
    plt.imshow(mark_boundaries(image.astype(np.uint8), mask))
    plt.title(f'Decisiones para la clase predicha: {clase_predicha}')
    plt.axis('off')
    st.pyplot()

Umap para reducción de dimensiones

A continuación vamos a usar un modelo para reducir dimensiones. Este modelo nos va a ayudar a visualizar nuestros dígitos, que tienen muchas dimensiones, a solo 2 dimensiones. Este modelo trata de mantener las distancias que hay entre los dígitos, dígitos cercanos estarán representados en puntos cercanos. Este modelo se llama UMAP.

Hay más librerías para reducir dimensiones, por ejemplo, PCA, TSNE, UMAP, etc. He decidido usar UMAP ya que es útil para visualización de datos y el modelo se puede usar para predecir nuevos dígitos.

Para más información, pueden ir al siguiente enlace.

La salida de este modelo es el siguiente gráfico. En este gráfico se ubican cada dígito del conjunto de entrenamiento, cada color indica la clase real y el punto de color negro, la representación del número que estamos prediciendo.

Si, por ejemplo, vemos que nuestro punto está dentro de una clase, podemos decir que el modelo está seguro de su predicción, pero si está en los bordes o, quizás, en medio de dos clases, quiere decir que nuestro modelo no está seguro de su predicción:

# REDUCCIÓN DE DIMENSIONES USANDO UMAP

def load_umap():
    '''Cargando el modelo UMAP '''
    global umap_model
    global umap_train_df
    if umap_model is None or umap_train_df is None:
        umap_model = pickle.load(open(config.MODEL_UMAP, 'rb'))
        umap_train_df = pd.read_csv(config.UMAP_TRAIN)
    return umap_model, umap_train_df


def plot_umap(img):
    ''' Reducción de dimensiones usando UMAP '''
    global model
    global emb_model
    global umap_model
    global umap_train_df
    model, emb_model = load_model()
    umap_model, umap_train_df = load_umap()
    emb_test = emb_model.predict(img.reshape(-1, 28, 28, 3), verbose=0)
    emb_test = emb_test.reshape(1, -1)
    umap_test = umap_model.transform(emb_test)
    umap_test = umap_test.reshape(1, -1)

    plt.figure(figsize=(12, 10))
    sns.scatterplot(x='x0', y='x1', alpha=0.1, hue='target',
                    legend='full', data=umap_train_df,
                    palette='Paired_r')
    plt.scatter(umap_test[0, 0], umap_test[0, 1], s=100, c='k')
    st.pyplot()

NearestNeighbors para búsqueda de imágenes similares

Y, por último, pero no menos importante, podemos buscar dígitos parecidos, de nuestro conjunto de entrenamiento, al dígito que hemos escrito.

Esto lo hacemos mediante NearestNeighbors.

Haciendo uso de la distancia coseno, podemos buscar los dígitos más cercanos al dígito que hemos escrito. Esto nos puede ayudar a entender también si a nuestro modelo le ha costado predecir el dígito correctamente.

Para más información de este modelo, pueden ir al siguiente enlace.

El siguiente gráfico muestra un ejemplo de este modelo. Para el dígito escrito, muestra 5 dígitos más cercanos y la clase real (escrito como el título). Si al modelo no le ha costado predecir la clase, los dígitos similares deberían ser de la misma clase. Si al modelo si le ha costado predecir la clase, los dígitos similares probablemente pertenecerán a distintas clases:

# IMAGENES SIMILARES USANDO NearestNeighbors

def load_data():
    ''' Cargando los datos '''
    global X_train
    global y_train
    if X_train is None or y_train is None:
        (X, y), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
        X = X.reshape(-1, 28, 28, 1)
        X_train, X_val, y_train, y_test = train_test_split(X, y,
                                                           test_size=0.2,
                                                           random_state=123456)
        X_train = to_rgb(X_train)
        X_train = X_train.reshape(-1, 28, 28, 3)
    return X_train, y_train


def load_nearest_model():
    ''' Cargando modelo de vecinos cercanos '''
    global nearest_model
    if nearest_model is None:
        nearest_model = pickle.load(open(config.MODEL_SKLEARN_NN, 'rb'))
    return nearest_model


def plot_similares(img):
    ''' Imágenes similares usando NearestNeighbors '''
    global model
    global emb_model
    global X_train
    global y_train
    global nearest_model
    model, emb_model = load_model()
    X_train, y_train = load_data()
    nearest_model = load_nearest_model()

    emb_test = emb_model.predict(img.reshape(1, 28, 28, 3), verbose=0)
    emb_test = emb_test.reshape(1, -1)

    distances_test, pred_nearest_test = nearest_model.kneighbors(emb_test)

    pred_nearest_test = pred_nearest_test.ravel()

    plt.figure(figsize=(15, 10))

    for i in range(nearest_model.n_neighbors):
        plt.subplot(1, nearest_model.n_neighbors, (i + 1))
        plt.imshow(X_train[pred_nearest_test[i]].astype(np.uint8),
                   cmap=plt.cm.binary)
        plt.title(y_train[pred_nearest_test[i]])
        plt.axis('off')
    plt.axis('off')
    st.pyplot()

Creando nuestra aplicación web

Por último, haremos uso de la librería STREAMLIT para poner nuestras funciones dentro de una aplicación web.

El código es el siguiente. Puede ver que los pasos son muy sencillos.

st.title('¡RECONOCE DÍGITOS!')

st.markdown('''
La siguiente aplicación intenta predecir el dígito escrito.
* Usamos redes neuronales convolucionales con Tensorflow.
* Para identificar reglas de decisión usamos LIME.
* Para reducir dimensiones usamos UMAP.
* Para buscar imágenes similares usamos NearestNeighbors.
''')

st.markdown('''¡ESCRIBA UN DÍGITO, INTENTARÉ PREDECIRLO!''')

canvas_result = st_canvas(
    fill_color='#000000',
    stroke_width=20,
    stroke_color='#FFFFFF',
    background_color='#000000',
    width=config.SIZE_DRAW,
    height=config.SIZE_DRAW,
    drawing_mode='freedraw',
    key='canvas'
)


if canvas_result.image_data is not None:
    if st.button('PREDECIR'):
        image_array = canvas_result.image_data.astype(np.uint8)
        img = prepara_img(image_array=image_array)

        st.subheader('PREDICCIÓN')
        predict_class(img)

        st.subheader('LIME PARA REGLAS DE DECISIÓN')
        plot_rules(img)

        st.subheader('UMAP PARA REDUCCIÓN DE DIMENSIONES')
        plot_umap(img)

        st.subheader('VECINOS CERCANOS PARA BUSCAR IMÁGENES SIMILARES EN EL CONJUNTO DE ENTRENAMIENTO')
        plot_similares(img)

Probando nuestra Aplicación Web

Y ¡Listo!, tenemos nuestra aplicación web.

Si quieren ver el código completo, pueden ir al siguiente enlace:

https://github.com/Jazielinho/digit_recognition/blob/master/digit_app.py

Para poder ejecutar la aplicación web, tenemos que escribir lo siguiente:


streamlit run digit_app.py

Y eso es todo, automáticamente aparecerá una pestaña web donde podrán utilizar la aplicación.

El siguiente vídeo de YouTube muestra la aplicación web trabajando:

Y eso es todo.

¡Cualquier comentario es bienvenido!

close

¡No te pierdas mis últimas publicaciones!

¡No te enviaré spam!

Una respuesta a «Cree su aplicación web para identificar dígitos usando Python, Tensorflow, Lime, Umap, Sklearn y Streamlit – Parte 2»

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *