diff --git a/lab3/lab3.py b/lab3/lab3.py index b9c53bf..474f584 100644 --- a/lab3/lab3.py +++ b/lab3/lab3.py @@ -18,6 +18,8 @@ trt = False +if len(sys.argv) == 2: + trt = (sys.argv[1] == "trt") # Select device @@ -102,7 +104,7 @@ def predict(image): def process(image): - fig = plt.figure(figsize=(10, 10)) + fig = plt.figure(figsize=(6, 6)) sub = fig.add_subplot(1,1,1) index = predict(image) @@ -111,22 +113,16 @@ def process(image): sub.set_title(classes[index]) plt.axis('off') - image.thumbnail((128, 128)) + image.thumbnail((256, 256)) plt.imshow(image) plt.savefig('out/' + str(index) + '.png') # plt.show() +print("processing images...") -if __name__ == "__main__": +for i, image in enumerate(images): - if len(sys.argv) == 2: - trt = (sys.argv[1] == "trt") + print("processing {current} of {all}...".format(current = i + 1, all = len(images))) - print("processing images...") - - for i, image in enumerate(images): - - print("processing {current} of {all}...".format(current = i + 1, all = len(images))) - - process(image) \ No newline at end of file + process(image) \ No newline at end of file