Cree su aplicación web para identificar dígitos usando Python, Tensorflow, Lime, Umap, Sklearn y Streamlit – Parte 2
Índice
Introducción
Esta sección es la segunda parte de:
En la anterior publicación entrenamos nuestros modelos de Machine Learning que nos ayudan no solo a identificar los dígitos escritos a mano, también nos ayuda a poder entender las reglas que ha tenido en cuenta para predecir si un dígito es, por ejemplo, 0 o 1, …
En esta publicación crearemos nuestra aplicación web usando StreamLit.
Preparando las funciones
En esta sección vamos a preparar nuestras funciones para nuestra Aplicación Web:
Importando librerías
Empezamos importando las librerías necesarias:
- pandas: para leer nuestra tabla de datos.
- numpy: para realizar operaciones con matrices.
- cv2: para tratamiento de imágenes.
- Tensorflow: para la predicción de nuestro modelo Deep Learning.
- stramlit: para crear nuestra aplicación web.
- streamlit_drawable_canvas: para poder graficar en nuestra aplicación web.
- sklearn: para usar los modelos de NearestNeighbors.
- skimage: para añadir opciones de gráficos.
- seaborn y matplotlib: para gráficos.
- pickle y dill: para leer nuestros modelos guardados.
- config: en este fichero se encuentra la ruta donde hemos guardado nuestros modelos.
import pandas as pd import numpy as np import cv2 import tensorflow as tf import streamlit as st from streamlit_drawable_canvas import st_canvas from sklearn.model_selection import train_test_split from skimage.segmentation import mark_boundaries from skimage.color import label2rgb import matplotlib.pyplot as plt import seaborn as sns import pickle import dill import config st.set_option('deprecation.showPyplotGlobalUse', False)
Pasos previos
Antes de usar nuestras funciones necesitamos definir algunos objetos. Estos objetos almacenarán nuestros modelos de Machine Learning e información importante, como DataFrame para graficar los embeddings, y nuestras matrices para graficar los dígitos:
model = None emb_model = None X_train = None y_train = None nearest_model = None explainer = None segmenter = None umap_model = None umap_train_df = None
Funciones para tratamiento de imágenes
Las siguientes funciones realizan lo siguiente:
- to_rgb: nuestros dígitos solo tienen un canal, este canal es por la escala de grises. Para eso añadimos dos canales más, para que nuestros dígitos tengan 3 canales: RGB y lo podamos graficar y usar las opciones de LIME.
- prepara_img: recibe una imagen que tiene dimensión 192 x 192 y lo transforma a una dimensiones de 28×28, luego añade los 2 canales faltantes haciendo uso de la función to_rgb.
# PREPARANDO IMAGEN def to_rgb(x): ''' Convertimos una imagen de escala de grises a RGB''' x_rgb = np.zeros((x.shape[0], 28, 28, 3)) for i in range(3): x_rgb[..., i] = x[..., 0] return x_rgb.reshape(-1, 28, 28, 3) def prepara_img(image_array): ''' preparamos una imagen para predecir el dígito ''' img = cv2.resize(image_array, (28, 28)) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img = img.reshape(-1, 28, 28, 1) return to_rgb(img)
Predicción de dígitos
Empezamos con las predicciones de nuestros modelo Deep Learning. Nuestro modelo ha sido entrenado en nuestra publicación anterior.
Nuestro modelo recibe como entrada una matriz de tamaño 28 x 28 x 3 y devuelve la probabilidad de pertenecer a cada digito. La siguiente imagen describe el proceso de predicción:
Luego las probabilidades se van a graficar de la siguiente manera:
# CARGANDO Y PREDICIENDO LOS DIGITOS def load_model(): ''' Cargando el modelo RED NEURONAL ''' global model global emb_model if model is None or emb_model is None: model_file = open(config.MODEL_PATH_JSON, 'r') model = model_file.read() model_file.close() model = tf.keras.models.model_from_json(model) model.load_weights(config.MODEL_PATH_H5) emb_model = tf.keras.models.Model(model.input, model.get_layer('embedding').output) return model, emb_model def predict_class(img): ''' Calculamos y graficamos las predicciones ''' global model global emb_model model, emb_model = load_model() predictions = model.predict(img) predictions = predictions.ravel() clase_predicha = int(predictions.argmax()) prob_ = 100 * predictions[clase_predicha] pred_df = pd.DataFrame(predictions, index=range(10)) st.subheader(f'Clase predicha: {clase_predicha}, probabilidad: {prob_:.2f}') st.bar_chart(pred_df)
Lime para entender las predicciones
Si bien nuestro modelo nos predice con poco error el dígito que hemos escrito, es interesante saber qué pixeles ha tenido en cuenta para hacer la predicción.
Un modelo Deep Learning es una caja negra, es decir, aunque tenga buena precisión, es dificil entender lo que aprendió.
Para poder entender las reglas de nuestro modelo, haremos uso de la librería Lime. Esta librería trata de explicar qué están haciendo nuestros modelos de Machine Learning.
La siguiente imagen explica el funcionamiento de esta librería. Para más información puedes ir al siguiente enlace:
Tenemos que tener en cuenta que la explicación es para una imagen en particular. Esto viene bien cuando queremos analizar los dígitos mal clasificados y tratar de entender por qué han sido mal clasificados.
La salida del algoritmo LIME se van a graficar de la siguiente manera.
En el primer gráfico podemos ver el dígito escrito y distintas reglas para cada clase, desde 0 a 9. Las reglas se pueden ver de color rojizo.
En el segundo gráfico podemos ver el dígito escrito y las reglas, de color amarillo, que ha tenido en cuenta nuestro modelo para la clase predicha.
# DECISIONES USANDO LIME def load_explainer(): ''' Cargando el modelo LIME ''' global explainer global segmenter if explainer is None or segmenter is None: explainer = dill.load(open(config.MODEL_EXPLAINER, 'rb')) segmenter = dill.load(open(config.MODEL_SEGMENTER, 'rb')) return explainer, segmenter def plot_rules(img): ''' Obtenemos las reglas de decisión usando LIME ''' global model global emb_model global explainer global segmenter model, emb_model = load_model() explainer, segmenter = load_explainer() X_eval = img.reshape(28, 28, 3) explanation = explainer.explain_instance(X_eval, classifier_fn=model.predict, top_labels=10, hide_color=0, num_samples=100, segmentation_fn=segmenter) plt.figure(figsize=(15, 10)) for i in range(10): temp, mask = explanation.get_image_and_mask(i, positive_only=True, num_features=1000, hide_rest=False, min_weight=0.01) plt.subplot(2, 5, (i + 1)) plt.imshow(label2rgb(mask.astype(np.uint8), X_eval.astype(np.uint8), bg_label=0), interpolation='nearest') plt.title(f'Positivo para clase: {i}') plt.axis('off') plt.axis('off') st.pyplot() clase_predicha = model.predict(X_eval.reshape((1, 28, 28, 3))).argmax(axis=1)[0] image, mask = explanation.get_image_and_mask(clase_predicha, positive_only=True, hide_rest=False) plt.imshow(X_eval.astype(np.uint8)) plt.imshow(mark_boundaries(image.astype(np.uint8), mask)) plt.title(f'Decisiones para la clase predicha: {clase_predicha}') plt.axis('off') st.pyplot()
Umap para reducción de dimensiones
A continuación vamos a usar un modelo para reducir dimensiones. Este modelo nos va a ayudar a visualizar nuestros dígitos, que tienen muchas dimensiones, a solo 2 dimensiones. Este modelo trata de mantener las distancias que hay entre los dígitos, dígitos cercanos estarán representados en puntos cercanos. Este modelo se llama UMAP.
Hay más librerías para reducir dimensiones, por ejemplo, PCA, TSNE, UMAP, etc. He decidido usar UMAP ya que es útil para visualización de datos y el modelo se puede usar para predecir nuevos dígitos.
Para más información, pueden ir al siguiente enlace.
La salida de este modelo es el siguiente gráfico. En este gráfico se ubican cada dígito del conjunto de entrenamiento, cada color indica la clase real y el punto de color negro, la representación del número que estamos prediciendo.
Si, por ejemplo, vemos que nuestro punto está dentro de una clase, podemos decir que el modelo está seguro de su predicción, pero si está en los bordes o, quizás, en medio de dos clases, quiere decir que nuestro modelo no está seguro de su predicción:
# REDUCCIÓN DE DIMENSIONES USANDO UMAP def load_umap(): '''Cargando el modelo UMAP ''' global umap_model global umap_train_df if umap_model is None or umap_train_df is None: umap_model = pickle.load(open(config.MODEL_UMAP, 'rb')) umap_train_df = pd.read_csv(config.UMAP_TRAIN) return umap_model, umap_train_df def plot_umap(img): ''' Reducción de dimensiones usando UMAP ''' global model global emb_model global umap_model global umap_train_df model, emb_model = load_model() umap_model, umap_train_df = load_umap() emb_test = emb_model.predict(img.reshape(-1, 28, 28, 3), verbose=0) emb_test = emb_test.reshape(1, -1) umap_test = umap_model.transform(emb_test) umap_test = umap_test.reshape(1, -1) plt.figure(figsize=(12, 10)) sns.scatterplot(x='x0', y='x1', alpha=0.1, hue='target', legend='full', data=umap_train_df, palette='Paired_r') plt.scatter(umap_test[0, 0], umap_test[0, 1], s=100, c='k') st.pyplot()
NearestNeighbors para búsqueda de imágenes similares
Y, por último, pero no menos importante, podemos buscar dígitos parecidos, de nuestro conjunto de entrenamiento, al dígito que hemos escrito.
Esto lo hacemos mediante NearestNeighbors.
Haciendo uso de la distancia coseno, podemos buscar los dígitos más cercanos al dígito que hemos escrito. Esto nos puede ayudar a entender también si a nuestro modelo le ha costado predecir el dígito correctamente.
Para más información de este modelo, pueden ir al siguiente enlace.
El siguiente gráfico muestra un ejemplo de este modelo. Para el dígito escrito, muestra 5 dígitos más cercanos y la clase real (escrito como el título). Si al modelo no le ha costado predecir la clase, los dígitos similares deberían ser de la misma clase. Si al modelo si le ha costado predecir la clase, los dígitos similares probablemente pertenecerán a distintas clases:
# IMAGENES SIMILARES USANDO NearestNeighbors def load_data(): ''' Cargando los datos ''' global X_train global y_train if X_train is None or y_train is None: (X, y), (X_test, y_test) = tf.keras.datasets.mnist.load_data() X = X.reshape(-1, 28, 28, 1) X_train, X_val, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123456) X_train = to_rgb(X_train) X_train = X_train.reshape(-1, 28, 28, 3) return X_train, y_train def load_nearest_model(): ''' Cargando modelo de vecinos cercanos ''' global nearest_model if nearest_model is None: nearest_model = pickle.load(open(config.MODEL_SKLEARN_NN, 'rb')) return nearest_model def plot_similares(img): ''' Imágenes similares usando NearestNeighbors ''' global model global emb_model global X_train global y_train global nearest_model model, emb_model = load_model() X_train, y_train = load_data() nearest_model = load_nearest_model() emb_test = emb_model.predict(img.reshape(1, 28, 28, 3), verbose=0) emb_test = emb_test.reshape(1, -1) distances_test, pred_nearest_test = nearest_model.kneighbors(emb_test) pred_nearest_test = pred_nearest_test.ravel() plt.figure(figsize=(15, 10)) for i in range(nearest_model.n_neighbors): plt.subplot(1, nearest_model.n_neighbors, (i + 1)) plt.imshow(X_train[pred_nearest_test[i]].astype(np.uint8), cmap=plt.cm.binary) plt.title(y_train[pred_nearest_test[i]]) plt.axis('off') plt.axis('off') st.pyplot()
Creando nuestra aplicación web
Por último, haremos uso de la librería STREAMLIT para poner nuestras funciones dentro de una aplicación web.
El código es el siguiente. Puede ver que los pasos son muy sencillos.
st.title('¡RECONOCE DÍGITOS!') st.markdown(''' La siguiente aplicación intenta predecir el dígito escrito. * Usamos redes neuronales convolucionales con Tensorflow. * Para identificar reglas de decisión usamos LIME. * Para reducir dimensiones usamos UMAP. * Para buscar imágenes similares usamos NearestNeighbors. ''') st.markdown('''¡ESCRIBA UN DÍGITO, INTENTARÉ PREDECIRLO!''') canvas_result = st_canvas( fill_color='#000000', stroke_width=20, stroke_color='#FFFFFF', background_color='#000000', width=config.SIZE_DRAW, height=config.SIZE_DRAW, drawing_mode='freedraw', key='canvas' ) if canvas_result.image_data is not None: if st.button('PREDECIR'): image_array = canvas_result.image_data.astype(np.uint8) img = prepara_img(image_array=image_array) st.subheader('PREDICCIÓN') predict_class(img) st.subheader('LIME PARA REGLAS DE DECISIÓN') plot_rules(img) st.subheader('UMAP PARA REDUCCIÓN DE DIMENSIONES') plot_umap(img) st.subheader('VECINOS CERCANOS PARA BUSCAR IMÁGENES SIMILARES EN EL CONJUNTO DE ENTRENAMIENTO') plot_similares(img)
Probando nuestra Aplicación Web
Y ¡Listo!, tenemos nuestra aplicación web.
Si quieren ver el código completo, pueden ir al siguiente enlace:
https://github.com/Jazielinho/digit_recognition/blob/master/digit_app.py
Para poder ejecutar la aplicación web, tenemos que escribir lo siguiente:
streamlit run digit_app.py
Y eso es todo, automáticamente aparecerá una pestaña web donde podrán utilizar la aplicación.
El siguiente vídeo de YouTube muestra la aplicación web trabajando:
Y eso es todo.
¡Cualquier comentario es bienvenido!
Una respuesta a «Cree su aplicación web para identificar dígitos usando Python, Tensorflow, Lime, Umap, Sklearn y Streamlit – Parte 2»