Skip to content
This repository was archived by the owner on Jan 1, 2023. It is now read-only.

Commit 2afaaa5

Browse files
authored
Merge pull request #15 from norkator/Feature/gpu-memory-problem-fix
Feature/gpu memory problem fix
2 parents 515e688 + 40f1195 commit 2afaaa5

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

libraries/fast_srgan/infer_oi.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from argparse import ArgumentParser
2+
import tensorflow as tf
23
from tensorflow import keras
4+
from keras import backend as kb
35
import numpy as np
46
import cv2
57
import os
@@ -11,15 +13,22 @@
1113
# Model path
1214
model_path_file_name = os.getcwd() + '/libraries/fast_srgan/models/generator.h5'
1315

14-
# Load model to memory
15-
# Change model input shape to accept all size inputs
16-
model = keras.models.load_model(model_path_file_name, compile=False)
17-
inputs = keras.Input((None, None, 3))
18-
output = model(inputs)
19-
model = keras.models.Model(inputs, output)
16+
# Set Keras TensorFlow session config
17+
config = tf.compat.v1.ConfigProto()
18+
config.gpu_options.per_process_gpu_memory_fraction = 0.9 # Use 90%
19+
config.gpu_options.allow_growth = True
20+
tf_session = tf.compat.v1.Session(config=config)
21+
tf.compat.v1.keras.backend.set_session(tf_session)
2022

2123

2224
def process_super_resolution_images(sr_image_objects):
25+
# Load model to memory
26+
# Change model input shape to accept all size inputs
27+
model = keras.models.load_model(model_path_file_name, compile=False)
28+
inputs = keras.Input((None, None, 3))
29+
output = model(inputs)
30+
model = keras.models.Model(inputs, output)
31+
2332
# Loop over all images
2433
# Input and output image is full path + filename including extension
2534
for sr_image_object in sr_image_objects:
@@ -36,7 +45,8 @@ def process_super_resolution_images(sr_image_objects):
3645
print('Sr processing img size: ' + str(original_h) + ':' + str(original_w))
3746

3847
# Check if image is not too big
39-
if original_w < 1200 and original_h < 1200:
48+
# Original size was 1200 but testing with smaller
49+
if original_w < 1000 and original_h < 1000:
4050
# Convert to RGB (opencv uses BGR as default)
4151
low_res = cv2.cvtColor(low_res, cv2.COLOR_BGR2RGB)
4252

@@ -57,11 +67,19 @@ def process_super_resolution_images(sr_image_objects):
5767

5868
# Save sr image data to object
5969
sr_image_object.set_sr_image_data(sr)
70+
71+
# Clear sr object
72+
sr = None
6073
else:
6174
# Save original image
6275
sr_image_object.set_sr_image_data(low_res)
6376

6477
except Exception as e:
6578
print(e)
6679

80+
# Clear model
81+
model = None
82+
tf.compat.v1.reset_default_graph() # Try free memory from Tensorflow
83+
kb.clear_session() # Clear Keras session
84+
6785
return sr_image_objects

0 commit comments

Comments
 (0)