Enhancements to process images through PIL and also add RESNET50B

This commit is contained in:
2024-08-01 17:28:00 -04:00
parent c568cb8aaa
commit e69c8cd287
2 changed files with 72 additions and 23 deletions

View File

@@ -36,20 +36,34 @@
import keras import keras
from keras.preprocessing.image import ImageDataGenerator from keras.preprocessing.image import ImageDataGenerator
from flask import Flask, render_template, request from flask import Flask, render_template, send_file, request
import numpy as np import numpy as np
import tensorflow import tensorflow
import sys import sys
import os import os
import io import io
import PIL.Image import PIL.Image
import uuid
#init flaskapp #init flaskapp
app = Flask(__name__) app = Flask(__name__)
# def save_image(imgData): # use this on image passed in to the API
# with open('output.jpg','wb') as output: def save_image(imgData):
# output.write(imgData) filename=str(uuid.uuid4())+'.jpg'
print('saving {filename}'.format(filename=filename))
with open(filename,'wb') as output:
output.write(imgData)
# Use this after processing with PIL
def save_pil_image(image):
filename=str(uuid.uuid4())+'.jpg'
print('saving {filename}'.format(filename=filename))
imageByteArray = io.BytesIO()
image.save(imageByteArray, format='JPEG')
imageByteArray = imageByteArray.getvalue()
with open(filename,'wb') as output:
output.write(imageByteArray)
def render_index(): def render_index():
@@ -76,21 +90,22 @@ def index():
def ping(): def ping():
return "Alive" return "Alive"
# @app.route('/predict_vgg16', methods=['GET','POST']) @app.route('/predict_vgg16', methods=['GET','POST'])
# def predict_vgg16(): def predict_vgg16():
# test_image=request.get_data() test_image=request.get_data()
# test_image = PIL.Image.open(io.BytesIO(test_image)) test_image = PIL.Image.open(io.BytesIO(test_image))
# save_pil_image(test_image)
# test_image = test_image.convert('L') # test_image = test_image.convert('L')
# test_array=keras.preprocessing.image.img_to_array(test_image) test_array=keras.preprocessing.image.img_to_array(test_image)
# batch_test_array=np.array([test_array]) batch_test_array=np.array([test_array])
# predictions=vgg16_model.predict(batch_test_array) predictions=vgg16_model.predict(batch_test_array)
# if type(predictions) == list: if type(predictions) == list:
# average_prediction = sum(predictions)/len(predictions) average_prediction = sum(predictions)/len(predictions)
# threshold_output = np.where(average_prediction > 0.5, 1, 0) threshold_output = np.where(average_prediction > 0.5, 1, 0)
# else : else :
# threshold_output = np.where(predictions > 0.5, 1, 0) threshold_output = np.where(predictions > 0.5, 1, 0)
# response=str(predictions)+'-->'+str(threshold_output) response=str(predictions)+'-->'+str(threshold_output)
# return response return response
@app.route('/predict_resnet50', methods=['GET','POST']) @app.route('/predict_resnet50', methods=['GET','POST'])
def predict_resnet50(): def predict_resnet50():
@@ -109,6 +124,24 @@ def predict_resnet50():
response=str(predictions)+'-->'+str(threshold_output) response=str(predictions)+'-->'+str(threshold_output)
return response return response
# This version expects the image to be of the form (x,x,3).
@app.route('/predict_resnet50B', methods=['GET','POST'])
def predict_resnet50B():
print('/predict_resnet50B')
test_image=request.get_data()
save_image(test_image)
test_image = PIL.Image.open(io.BytesIO(test_image))
test_array=keras.preprocessing.image.img_to_array(test_image)
batch_test_array=np.array([test_array])
predictions=resnet50b_model.predict(batch_test_array)
if type(predictions) == list:
average_prediction = sum(predictions)/len(predictions)
threshold_output = np.where(average_prediction > 0.5, 1, 0)
else :
threshold_output = np.where(predictions > 0.5, 1, 0)
response=str(predictions)+'-->'+str(threshold_output)
return response
@app.route('/predict_lenet5', methods=['GET','POST']) @app.route('/predict_lenet5', methods=['GET','POST'])
def predict_lenet5(): def predict_lenet5():
print('/predict_lenet5') print('/predict_lenet5')
@@ -126,13 +159,29 @@ def predict_lenet5():
response=str(predictions)+'-->'+str(threshold_output) response=str(predictions)+'-->'+str(threshold_output)
return response return response
# This method is used to process an image through PIL and send it back to the client. The client can then used this processed image as part of the training data
# so that the model can adapt to images that are processed through PIL
@app.route('/process_image', methods=['GET','POST'])
def process_image():
print('/process_image')
image=request.get_data()
image = PIL.Image.open(io.BytesIO(image))
imageByteArray = io.BytesIO()
image.save(imageByteArray, format='JPEG')
imageByteArray = imageByteArray.getvalue()
print('processed {length} bytes.'.format(length=len(imageByteArray)))
return send_file(io.BytesIO(imageByteArray), mimetype='image/jpeg', as_attachment=True, download_name='%s.jpg' % str(uuid.uuid4()))
if __name__ == '__main__': if __name__ == '__main__':
resnet50_model_name='../Weights/resnet50.h5' resnet50_model_name='../Weights/resnet50.h5'
resnet50_model = keras.models.load_model(resnet50_model_name) resnet50_model = keras.models.load_model(resnet50_model_name)
# vgg16_model_name='../Weights/model_vgg16.h5' resnet50b_model_name='../Weights/resnet50B.h5'
# vgg16_model=keras.models.load_model(vgg16_model_name) resnet50b_model = keras.models.load_model(resnet50b_model_name)
vgg16_model_name='../Weights/vggnet16.h5'
vgg16_model=keras.models.load_model(vgg16_model_name)
lenet_model_name='../Weights/lenet5.h5' lenet_model_name='../Weights/lenet5.h5'
lenet_model=keras.models.load_model(lenet_model_name) lenet_model=keras.models.load_model(lenet_model_name)

View File

@@ -38,7 +38,7 @@ from matplotlib import pyplot
import numpy as np import numpy as np
import tensorflow import tensorflow
def resnet50(input_shape,classes): def resnet50(input_shape,classes,model_name='resnet50'):
x_input=keras.Input(input_shape) x_input=keras.Input(input_shape)
x=Conv2D(64,(7,7),strides=(2,2),name='conv1')(x_input) x=Conv2D(64,(7,7),strides=(2,2),name='conv1')(x_input)
@@ -74,7 +74,7 @@ def resnet50(input_shape,classes):
x=Dense(classes,activation='sigmoid',name='fc'+str(classes))(x) x=Dense(classes,activation='sigmoid',name='fc'+str(classes))(x)
else: else:
x=Dense(classes,activation='softmax',name='fc'+str(classes))(x) x=Dense(classes,activation='softmax',name='fc'+str(classes))(x)
model=keras.Model(inputs=x_input,outputs=x,name='resnet50') model=keras.Model(inputs=x_input,outputs=x,name=model_name)
return model return model