#!/usr/bin/env python3

import sys
import os
import pickle

os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # disable tensorflow messages
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

# keras is our machine learning library
import keras
from keras.preprocessing import image
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Flatten, Dense
from keras.activations import relu, elu

import numpy as np

import chances
import talos

from talos.model import network_shape
from talos.model.layers import hidden_layers

# rasterio is a library we use to read multispectral images
import rasterio

# import our generator
from generators import SplitSetImageGenerator

################### HELPER FUNCTIONS ###################

def save_object(obj, filename):
	"""Pickle ``obj`` into the file at ``filename``.

	The file is opened for binary writing, so existing content is replaced.
	Pickle protocol 2 is used for the stored data.
	"""
	with open(filename, 'wb') as sink:
		writer = pickle.Pickler(sink, protocol=2)
		writer.dump(obj)

def load_object(filename):
	"""Return the object stored as a pickle in the file at ``filename``.

	NOTE: unpickling can run arbitrary code, so only load files that come
	from a trusted source.
	"""
	with open(filename, 'rb') as source:
		restored = pickle.Unpickler(source).load()
	return restored

def project_object(obj, *attributes):
	"""Return a dict mapping each name in ``attributes`` to that attribute of ``obj``.

	The attributes are read with ``getattr``, so a name that ``obj`` does not
	have raises ``AttributeError``.
	"""
	return {name: getattr(obj, name) for name in attributes}

def print_hyperparameter_search_stats(t):
	"""Print the 'params', 'data' and 'details' entries of a search result.

	``t`` is indexed with those three keys. Parameter value lists with 200 or
	more entries are shortened to their first three values, ``'...'`` and the
	last value. The 'data' entry is printed sorted by its ``val_acc`` column,
	highest value first.
	"""
	def shorten(values):
		# long value lists would flood the output: keep head and tail only
		if len(values)<200:
			return values
		return [values[0],values[1],values[2],'...',values[-1]]

	search_space={}
	for name,values in t['params'].items():
		search_space[name]=shorten(values)
	print(" *** params: ",search_space)
	print()
	#print(" *** peak_epochs_df ",type(t['peak_epochs_df']),len(t['peak_epochs_df'].index))
	#print(t['peak_epochs_df'].to_string())
	#print()
	results=t['data']
	print(" *** data ",type(results),len(results))
	ranked=results.sort_values('val_acc',ascending=False)
	print(ranked.to_string())
	print()
	details=t['details']
	print(" *** details ",type(details),len(details))
	print(details)
	print()

# images already loaded by read_image, keyed by "<path>,<rescale>"
read_image_cache={}
def read_image(path, rescale=None):
	"""Load an image with rasterio, caching the result per path and rescale.

	Args:
		path: location handed to ``rasterio.open``.
		rescale: optional factor the image data is multiplied with; ``None``
			leaves the data as read.

	Returns:
		The image data with its first axis moved to the last position
		(presumably bands-first to bands-last; confirm against rasterio).
		The array stored in ``read_image_cache`` is returned itself, not a
		copy, so a caller that modifies it in place also changes what later
		calls receive.
	"""
	key="{},{}".format(path,rescale)
	if key in read_image_cache:
		return read_image_cache[key]
	with rasterio.open(path) as img:
		data=np.moveaxis(img.read(),0,-1)
	# `is not None` instead of `!= None`: for an array-valued factor `!=`
	# compares elementwise and the result cannot be used as a condition
	if rescale is not None:
		data=data*rescale
	read_image_cache[key]=data
	return data

################### DATA CONFIGURATION ###################

def image_data_generator_dir_reader(path):
	"""List the image files below ``path`` together with their classes.

	The directory is scanned with the Keras ``ImageDataGenerator`` and its
	``flow_from_directory`` method.

	Args:
		path: root directory handed to ``flow_from_directory``.

	Returns:
		A 4-tuple ``(names, classes, class_names, class_indices)``.
		``names`` are the file names reported by the generator, prefixed with
		``path``, normalised and written with forward slashes. ``classes`` is
		the generator's ``classes`` attribute. ``class_names`` and
		``class_indices`` are the keys and the corresponding values of the
		generator's ``class_indices`` mapping as tuples; both are empty when
		that mapping is empty.
	"""
	# the scan prints to stdout; send that output to stderr while it runs
	previous_stdout=sys.stdout
	sys.stdout=sys.stderr
	try:
		ig = image.ImageDataGenerator()
		gen = ig.flow_from_directory(path)
	finally:
		# put back whatever stdout was before, even when the scan raised
		sys.stdout=previous_stdout
	names=[os.path.normpath(path+'/'+n.replace('\\','/')).replace('\\','/') for n in gen.filenames]
	if gen.class_indices:
		class_names,class_indices=zip(*gen.class_indices.items())
	else:
		# zip(*[]) yields nothing, which would shrink the result to 2 items
		class_names,class_indices=(),()
	return (names,gen.classes,class_names,class_indices)

# build the data generators
# arguments unpacked into .split() below; the resulting sets are bound in the
# order test, validation, training
# NOTE(review): presumably 0.2 and 0.4 are the test and validation fractions,
# with the remainder used for training - confirm against SplitSetImageGenerator
test_validation_train_split=[0.2,0.4]
# images are loaded through read_image; every set has verbose switched off
test_set,validation_set,training_set=[dataset.set(verbose=False) for dataset
		in SplitSetImageGenerator(image_load_function=read_image,scale=1.0/255)
				.add_dir(image_data_generator_dir_reader,'data/EuroSat/tif/')
				.shuffle()
				.split(*test_validation_train_split)]

# preload images to speed up training
# (only the validation and training sets; test_set is not preloaded here)
for s in [validation_set,training_set]:
	# verbose is switched on for the preload only, then the set is shuffled
	s.set(verbose=True).preload().set(verbose=False).shuffle()

# ensure that each epoch presents the same number of images of each class
training_set.set(max_per_class_and_epoch=400)

#################### MODEL DEFINITION ###################
#		
## this is not an optimized model, just a simple example
## for good results this model needs some thought
#
#model = Sequential()
#
#model.add(Conv2D(60, 5, input_shape=training_set.shape))
#model.add(Activation('relu'))
#model.add(MaxPooling2D(pool_size=(2, 2)))
#
#model.add(Conv2D(20, 5, input_shape=training_set.shape))
#model.add(Activation('relu'))
#model.add(MaxPooling2D(pool_size=(2, 2)))
#
#model.add(Flatten())
#
#model.add(Dense(120))
#model.add(Dense(60))
#model.add(Dense(training_set.num_classes))
#model.add(Activation('softmax'))
#
#model.compile(loss='categorical_crossentropy',
#	      optimizer=Adam(lr=0.0001,decay=0.001),
#	      metrics=['accuracy'])
#
#################### TRAINING ###################
#
#history = model.fit_generator(
#	training_set,
#	validation_data=validation_set,
#	epochs=120
#)

################### HYPERPARAMETER OPTIMIZATION ###################
# hyperparameter search space handed to talos.Scan: every key maps to the
# list of candidate values for that parameter
params = {
	# training settings: 'epoch' is passed to fit_generator as epochs,
	# 'batch_size' is set on the training and validation generators
	'epoch': [120],
	'batch_size': [32],
	# NOTE(review): presumably read by talos' hidden_layers; the convolution
	# layers in create_model use a fixed 'relu' - confirm
	'activation': [relu],
	# convolution layers at the beginning
	# 'conv_depth_*' become the first Conv2D argument and 'conv_size_*' the
	# second one, for the first and the last convolution layer
	'conv_hidden_layers': [1,2,3],
	'conv_depth_shape': ['brick'],
	'conv_size_shape': ['brick'],
	'conv_depth_first_neuron': [20,40,60],
	'conv_depth_last_neuron': [20,40,60],
	'conv_size_first_neuron': [5,7],
	'conv_size_last_neuron': [3,5],
	# fully connected layers at the end
	# (passed on to talos' hidden_layers inside the model function)
	'first_neuron': [32,64],
	'last_neuron': [32,64],
	'shapes': ['brick'],
	'hidden_layers': [1,2,3],
	'dropout': [0.05]
}

class __Config__:
	"""Bare attribute holder for the settings every model is compiled with."""


config = __Config__()
# the optimizer class and its keyword arguments are kept apart so that a new
# optimizer instance can be created for each model
config.optimizer = Adam
config.optimizer_parameters = dict(lr=0.0001, decay=0.001)
# loss and metrics given to model.compile
config.loss = 'categorical_crossentropy'
config.metric = ['accuracy']

def create_model(training_set,validation_set,verbose=False):
	"""Create the model-building function that is handed to ``talos.Scan``.

	Args:
		training_set: generator the models are fitted on; its ``shape`` and
			``num_classes`` attributes size the first and the last layer.
		validation_set: generator used as validation data while fitting.
		verbose: Keras verbosity used when the hyperparameters of a round
			contain no ``'verbose'`` entry.

	Returns:
		A function ``model(x_train, y_train, x_val, y_val, params)`` that
		returns ``(history, model)``. It ignores its four data arguments and
		fits on the two generators given here instead.
	"""
	def _create_conv_shape_(params):
		"""Return one ``(depth, size)`` pair per convolution layer as a list."""
		def shape(params):
			# layer values from first to last: both ends come straight from
			# params, anything in between from talos' network_shape
			if params['hidden_layers']==1:
				return [params['first_neuron']]
			if params['hidden_layers']==2:
				return [params['first_neuron'],params['last_neuron']]
			# copy first so the reduced layer count stays local to this call
			params=params.copy()
			params['hidden_layers']-=2
			s_list=network_shape.network_shape(params,params['last_neuron'])
			return [params['first_neuron'],*s_list,params['last_neuron']]
		# re-key the conv_* entries to the names shape() reads
		conv_depth_params={
			'hidden_layers': params['conv_hidden_layers'],
			'shapes': params['conv_depth_shape'],
			'first_neuron': params['conv_depth_first_neuron'],
			'last_neuron': params['conv_depth_last_neuron'],
		}
		conv_size_params={
			'hidden_layers': params['conv_hidden_layers'],
			'shapes': params['conv_size_shape'],
			'first_neuron': params['conv_size_first_neuron'],
			'last_neuron': params['conv_size_last_neuron'],
		}

		conv_depth_shape = shape(conv_depth_params)
		conv_size_shape = shape(conv_size_params)
		# a list rather than a bare zip object, which could be iterated only once
		return list(zip(conv_depth_shape,conv_size_shape))
#	def permutation_filter(shape,params):
#		conv_shape=_create_conv_shape_(params)
#		min_dim=min(shape[:2])
#		for i,(_,size) in enumerate(conv_shape):
#			min_dim-=size-1
#		return min_dim>=1

	def model(dummyXtrain,dummyYtrain,dummyXval,dummyYval,params):
		"""Build, compile and fit one network for the hyperparameters in ``params``."""
		conv_shape=_create_conv_shape_(params)

		net = Sequential()

		# convolution stack; only the first layer declares the input shape
		# NOTE(review): the activation is fixed to 'relu' here and
		# params['activation'] is not read - confirm this is intended
		for i,(depth,size) in enumerate(conv_shape):
			if i==0:
				net.add(Conv2D(depth, size, input_shape=training_set.shape))
			else:
				net.add(Conv2D(depth, size))
			net.add(Activation('relu'))

		net.add(Flatten())

		# fully connected part, added by talos' hidden_layers helper
		hidden_layers(net, params, params['last_neuron'])

		net.add(Dense(training_set.num_classes))
		net.add(Activation('softmax'))

		# config is the module-level settings object; it is only read here,
		# so no `global` declaration is needed
		optimizer=config.optimizer(**config.optimizer_parameters)
		net.compile(loss=config.loss,
			    optimizer=optimizer,
			    metrics=config.metric)

		training_set.batch_size=params['batch_size']
		validation_set.batch_size=params['batch_size']

		history = net.fit_generator(
			training_set,
			validation_data=validation_set,
			epochs=params['epoch'],
			# fall back to create_model's verbose when the round has none
			verbose=int(params.get('verbose',verbose)),
		)
		return history,net
#	model.permutation_filter=lambda params: permutation_filter(training_set.shape,params)
	return model


# verbosity of the Keras training output; it reaches the model function
# through params['verbose'] (set below)
verbose=True
round_limit=30 # NOTE Set this to however many rounds you want to test with

model=create_model(training_set,validation_set,verbose)
#try:
#	permutation_filter=model.permutation_filter
#except:
#	permutation_filter=None
# single-value entry, so every round gets the same verbose setting
params['verbose']=[verbose]

# talos.Scan is given one batch from each generator as x/y data; the model
# function ignores these arrays and trains on the generators themselves
dummyX,dummyY=training_set.__getitem__(0)
testX,testY=validation_set.__getitem__(0)
# NOTE(review): presumably resets the validation generator after the batch
# drawn above - confirm in SplitSetImageGenerator
validation_set.on_epoch_end()
tt = talos.Scan( x=dummyX
		,y=dummyY
		,params=params
		,model=model
		,x_val=testX
		,y_val=testY
		#,dataset_name='EuroSat'
		#,experiment_no='1'
		,experiment_name='example.csv'
		,print_params=True
		#,permutation_filter=permutation_filter
		#,search_method='random'
		,round_limit=round_limit
		#,clear_tf_session=True
		)

# debug output: all attributes of the Scan object and its round history
print(vars(tt),dir(tt))
print(tt.round_history)

# keep only the listed attributes of the Scan object and pickle them
t = project_object(tt,'params','saved_models','saved_weights','data','details','round_history')
save_object(t,'example.pickle')

print_hyperparameter_search_stats(t)


