#!/usr/bin/env python3

import contextlib
import os
import sys

# keras is our machine learning library
import keras
from keras.preprocessing import image
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Flatten, Dense

# import our generator
from generators import SplitSetImageGenerator

################### HELPER FUNCTIONS ###################

# Memo of already-loaded images, keyed by "path,rescale" (see read_image).
read_image_cache = {}
def read_image(path, rescale=None):
	"""Load the image at *path* as an array, optionally multiplied by *rescale*.

	Results are memoized in ``read_image_cache`` under the key
	``"{path},{rescale}"``, so repeated calls with the same arguments return
	the *same* array object. Callers must not modify it in place, or the
	cached entry changes too.

	:param path: filesystem path of the image file.
	:param rescale: optional factor multiplied into the pixel values;
		``None`` leaves the values as returned by ``image.img_to_array``.
	:returns: the (possibly rescaled) array for the image.
	"""
	key = "{},{}".format(path, rescale)
	try:
		return read_image_cache[key]
	except KeyError:
		pass  # cache miss: load and store below
	data = image.img_to_array(image.load_img(path))
	# `is not None` rather than `!= None`: for an array-valued rescale,
	# `!=` compares element-wise and `if` on that result raises ValueError.
	if rescale is not None:
		data = data * rescale
	read_image_cache[key] = data
	return data

################### DATA CONFIGURATION ###################

# Return the filenames and classes of the images under a directory, plus the
# class names and the class indices corresponding to those names.
def image_data_generator_dir_reader(path):
	"""List the images under *path* as Keras' ``flow_from_directory`` sees them.

	:param path: root directory handed to ``flow_from_directory``
		(presumably one sub-directory per class — TODO confirm for the dataset).
	:returns: a 4-tuple ``(names, classes, class_names, class_indices)``.
		``names`` are the generator's filenames joined onto *path*,
		normalized and using forward slashes. ``classes`` is ``gen.classes``
		(the classes of the images). ``class_names`` and ``class_indices``
		are parallel tuples unpacked from ``gen.class_indices``. When no
		classes are found the last two are empty tuples, so the result
		always has four elements.
	"""
	# Send anything flow_from_directory prints to stderr instead of stdout.
	# redirect_stdout restores the *previous* stdout (not necessarily
	# sys.__stdout__), and does so even if an exception is raised.
	with contextlib.redirect_stdout(sys.stderr):
		gen = image.ImageDataGenerator().flow_from_directory(path)
	# Normalize Windows separators to '/' both before and after normpath so
	# the paths look the same on every platform.
	names = [
		os.path.normpath(path + '/' + n.replace('\\', '/')).replace('\\', '/')
		for n in gen.filenames
	]
	pairs = gen.class_indices.items()
	# zip(*pairs) on an empty mapping yields nothing; fall back to two empty
	# tuples so callers can always unpack four values.
	class_names, class_ids = tuple(zip(*pairs)) if pairs else ((), ())
	return (names, gen.classes, class_names, class_ids)

# Build the data generators. All images under data/EuroSat/jpg/ (listed by
# image_data_generator_dir_reader, loaded via read_image, with scale=1.0/255
# passed to the generator) are shuffled once and then split into test,
# validation and training subsets, unpacked in that order.
# NOTE(review): how the two split values are interpreted is defined by
# SplitSetImageGenerator.split. They may be per-set fractions (test 0.2,
# validation 0.4, rest training) or cumulative boundaries; confirm.
test_validation_train_split=[0.2,0.4]
# Each split set gets .set(verbose=False); the value returned by that call is
# what is bound to test_set / validation_set / training_set.
test_set,validation_set,training_set=[dataset.set(verbose=False) for dataset
		in SplitSetImageGenerator(image_load_function=read_image,scale=1.0/255)
				.add_dir(image_data_generator_dir_reader,'data/EuroSat/jpg/')
				.shuffle()
				.split(*test_validation_train_split)]

# Preload the validation and training images to speed up training. Verbose
# output is enabled only around preload(), then each set is shuffled.
# The test set is not preloaded here.
for s in [validation_set,training_set]:
	s.set(verbose=True).preload().set(verbose=False).shuffle()

# Intended to make each epoch present the same number of images of each
# class. Presumably this caps every class at 400 images per epoch, which only
# balances classes that have at least 400 training images — TODO confirm.
training_set.set(max_per_class_and_epoch=400)

################### MODEL DEFINITION ###################
		
# A deliberately simple baseline CNN, not a tuned architecture; for good
# results this model needs more thought.

model = Sequential()

# Feature extractor: two conv -> ReLU -> 2x2 max-pool stages.
# Only the first layer of a Sequential model needs input_shape; later layers
# take their input shape from the preceding layer.
model.add(Conv2D(60, 5, input_shape=training_set.shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(20, 5))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())

# Classifier head. Consecutive Dense layers with no non-linearity between
# them compose into a single linear map, so each hidden Dense gets a ReLU.
model.add(Dense(120))
model.add(Activation('relu'))
model.add(Dense(60))
model.add(Activation('relu'))
model.add(Dense(training_set.num_classes))
model.add(Activation('softmax'))

# categorical_crossentropy expects one-hot targets — presumably what
# training_set yields; TODO confirm.
# NOTE: `lr` and `decay` are the Keras 2.x argument names; newer Keras
# releases rename `lr` to `learning_rate` and no longer accept `decay` on the
# default Adam, so this line needs updating when moving to them.
model.compile(loss='categorical_crossentropy',
	      optimizer=Adam(lr=0.0001,decay=0.001),
	      metrics=['accuracy'])

################### TRAINING ###################

# Train for 30 epochs on training_set, passing validation_set as the
# validation data, and keep whatever fit_generator returns in `history`.
# NOTE(review): fit_generator is deprecated in newer tf.keras, where Model.fit
# accepts generators directly; it is kept here for the Keras 2.x API this
# script uses.
history = model.fit_generator(training_set,
			      epochs=30,
			      validation_data=validation_set)

