
import sys
import os
import math
import psutil

import numpy as np
import keras
from keras.utils import to_categorical

def sizeof_fmt(num, suffix='B'):
	"""Render a quantity (typically a byte count) with binary IEC prefixes.

	The value is divided by 1024 until its magnitude falls below 1024 and is
	then printed with one decimal place, e.g. ``sizeof_fmt(2048) == '2.0KiB'``.
	Anything that still does not fit after the 'Zi' prefix is shown in 'Yi'.

	:param num: the quantity to format; negative values keep their sign
	:param suffix: unit symbol appended after the prefix (default ``'B'``)
	:returns: the formatted string
	"""
	prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
	scaled = num
	depth = 0
	# written as "not < 1024" (rather than ">= 1024") so NaN keeps being divided
	while depth < len(prefixes) and not abs(scaled) < 1024.0:
		scaled /= 1024.0
		depth += 1
	if depth == len(prefixes):
		return "%.1f%s%s" % (scaled, 'Yi', suffix)
	return "%3.1f%s%s" % (scaled, prefixes[depth], suffix)


# class SplitSetImageGenerator: signature
#	self.filenames
#	self.classes
#	self.classnames
#	self.batch_size
#	self.verbose
#	self.max_per_class_and_epoch
#	self.auto_shuffle
#	self.scale
#	self.image_cache
#	self.image_load_function # (path,scale=1,cache=dict|None) -> numpy.matrix; called without cache if it rejects that keyword
#	__init__(**all vars above**)
#		# intended as read only for client code
#		self.num_classes
#		self.shape
#		# intended as private
#		self.__indices__
#		self.__indices_for_class__
#	split(*float(0:1)) -> list(SplitSetImageGenerator)
#	shuffle() -> self
#	set(**attributes) -> self
#	add_dir(image_dir_reader,*paths) -> self
#	preload() -> self
#	get_filenames(indices) -> list
#	__getitem__(index)
#	__len__()
#	on_epoch_end()
#	NOTE to access classes and filenames, read them directly from the lists
#
# filenames and classes must always have the same length; the constructor throws otherwise
class SplitSetImageGenerator(keras.utils.Sequence):
	"""Keras ``Sequence`` yielding ``(X, Y)`` batches of images and one-hot labels.

	Every sample is a filename plus an integer class index into ``classnames``.
	Images are loaded on demand by ``image_load_function`` (which may memoise
	them in ``image_cache``) and labels are one-hot encoded with
	``to_categorical``.  The set can be grown with ``add_dir``, loaded up front
	with ``preload`` and partitioned with ``split``.  When
	``max_per_class_and_epoch`` is set, each epoch uses at most that many
	randomly chosen images of every class.
	"""
	def __recompute_indices__(self):
		# Rebuild __indices__, the sample order used by __getitem__ for the current epoch.
		if self.max_per_class_and_epoch is not None:
			# shuffle each class in place and keep only its first max_per_class_and_epoch samples
			self.__indices__=[]
			for class_indices in self.__indices_for_class__:
				np.random.shuffle(class_indices)
				self.__indices__.extend(class_indices[:int(self.max_per_class_and_epoch)])
		else:
			# NOTE same list object as __all_indices__, so shuffle() reorders both together
			self.__indices__=self.__all_indices__
	def __load_image__(self,filename):
		# Load one image with image_load_function, passing float(scale) when a scale is set.
		# The cache is passed as the ``cache`` keyword; a loader that does not accept that
		# keyword raises TypeError and is called again without it.  Only TypeError triggers
		# the retry, so unrelated failures (or KeyboardInterrupt) are not followed by a
		# second load attempt.
		args=(filename,) if self.scale is None else (filename,float(self.scale))
		try:
			return self.image_load_function(*args,cache=self.image_cache)
		except TypeError:
			return self.image_load_function(*args)
	def _print_memory(self):
		# Print the resident set size of the current process when verbose.
		if self.verbose:
			print('memory {} used'.format(sizeof_fmt(psutil.Process(os.getpid()).memory_info().rss)))
			sys.stdout.flush()
	# TODO possible enhancments
	#	* add a merge method that takes two SplitSetImageGenerator's
	#	* add a limit method that discards data to cap the number of images (optionally per class)
	def __init__(self,*,filenames=None,classes=None,classnames=None,image_cache=None,image_load_function,scale=None
			,batch_size=32,max_per_class_and_epoch=None,auto_shuffle=True,verbose=True):
		"""Create a generator over ``filenames`` labelled by ``classes``.

		:param filenames: image identifiers handed to image_load_function (copied)
		:param classes: integer index into classnames for every entry of filenames (copied)
		:param classnames: names of the classes (copied)
		:param image_cache: passed as the ``cache`` keyword to image_load_function
		:param image_load_function: callable ``(filename[, scale], cache=...)`` returning an
			object with a ``.shape``; it is retried without ``cache`` if that keyword is rejected
		:param scale: optional value converted with float() and passed to the loader
		:param batch_size: number of samples per batch
		:param max_per_class_and_epoch: optional cap on the samples drawn per class each epoch
		:param auto_shuffle: resample and reshuffle in on_epoch_end
		:param verbose: print progress information to stdout
		:raises Exception: if only some of filenames/classes/classnames are empty, if filenames
			and classes differ in length, or if a per-class cap is given without auto_shuffle

		The first image is loaded immediately to determine ``shape``.
		"""
		super().__init__()
		# Work on private copies: add_dir extends these lists in place, so they must be
		# neither a default shared by every instance nor the caller's own lists.
		filenames=[] if filenames is None else list(filenames)
		classes=[] if classes is None else list(classes)
		classnames=[] if classnames is None else list(classnames)
		self.filenames=filenames
		self.classes=classes
		self.classnames=classnames
		self.image_cache=image_cache
		self.image_load_function=image_load_function
		self.scale=scale
		self.batch_size=batch_size
		self.max_per_class_and_epoch=max_per_class_and_epoch
		self.auto_shuffle=auto_shuffle
		self.verbose=verbose
		if self.verbose:
			print(" *** Initizialising DataGenerator")
			sys.stdout.flush()
		if len({not filenames,not classes,not classnames})>1:
			raise Exception('ERROR if one of filenames, classes and classnames are non empty they must all be')
		if len(filenames)!=len(classes):
			raise Exception('Error missmatching length for files and classes')
		if max_per_class_and_epoch is not None and not auto_shuffle:
			raise Exception('Error auto_shuffle requiered if max_per_class_and_epoch is specified')
		# classnames is empty whenever filenames is (checked above), so this is 0 for an empty set
		self.num_classes=len(self.classnames)
		if self.filenames:
			self.shape=self.__load_image__(self.filenames[0]).shape
		else:
			self.shape=None
		self.__indices_for_class__=[[] for _ in self.classnames]
		for i,c in enumerate(self.classes):
			self.__indices_for_class__[c].append(i)
		if self.verbose and self.filenames:
			for i,cls in enumerate(self.__indices_for_class__):
				print("Class {} with {} images".format(i,len(cls)))
			sys.stdout.flush()
		self.__all_indices__=list(range(len(self.classes)))
		self.__recompute_indices__()
		self.__i__=0
	def set(self,**attributes):
		"""Assign several existing public attributes and rebuild the epoch indices.

		Every name is validated before anything is assigned, so an invalid name leaves the
		generator unchanged.  The call is announced when ``self.verbose`` is true, unless the
		call itself passes a falsy ``verbose``.

		:returns: self, for chaining
		:raises Exception: for a name starting with '_' or naming a missing attribute
		"""
		if self.verbose and attributes.get('verbose',True):
			print(" *** Setting attributes {}".format(attributes))
			sys.stdout.flush()
		for name in attributes:
			if name.startswith('_'):
				raise Exception('ERROR can not set underscore properties with SplitSetImageGenerator.set method')
			if not hasattr(self,name):
				raise Exception('ERROR no such property to set {} in SplitSetImageGenerator.set'.format(name))
		for name,value in attributes.items():
			setattr(self,name,value)
		self.__recompute_indices__()
		return self
	def add_dir(self,image_dir_reader,*paths):
		"""Append the images that ``image_dir_reader`` reports for each of ``paths``.

		``image_dir_reader(path)`` must return ``(names, rawclasses, classnames, classindices)``:
		the reader's class index ``classindices[k]`` denotes ``classnames[k]`` and ``rawclasses``
		holds one such index per entry of ``names``.  Class names that are already known keep
		their index; new names are appended.  ``shape`` is taken from the first image if it is
		still unknown.

		:returns: self, for chaining
		:raises Exception: if a reader returns different numbers of names and classes
		"""
		if self.verbose:
			print(' *** Adding images to set')
			sys.stdout.flush()
		for path in paths:
			if self.verbose:
				print("Reading images from directory {}".format(path))
				sys.stdout.flush()
			names,rawclasses,classnames,classindices=image_dir_reader(path)
			names=list(names)
			rawclasses=list(rawclasses)
			if len(names)!=len(rawclasses):
				raise Exception('Error missmatching length for files and classes')
			# translate the reader's class indices into indices of self.classnames
			mapping={}
			for name,cls in zip(classnames,classindices):
				if name in self.classnames:
					mapping[cls]=self.classnames.index(name)
				else:
					mapping[cls]=len(self.classnames)
					self.classnames.append(name)
					self.__indices_for_class__.append([])
			classes=[mapping[v] for v in rawclasses]
			offset=len(self.filenames)
			for i,c in enumerate(classes,start=offset):
				self.__indices_for_class__[c].append(i)
			self.filenames.extend(names)
			self.classes.extend(classes)
			self.num_classes=len(self.classnames)
		# guard: with no images at all there is nothing to measure
		if self.shape is None and self.filenames:
			self.shape=self.__load_image__(self.filenames[0]).shape
			if self.verbose:
				print("Image dimensions are {}".format(self.shape))
				sys.stdout.flush()
		if self.verbose:
			for i,cls in enumerate(self.__indices_for_class__):
				print("Class {} with {} images".format(i,len(cls)))
			sys.stdout.flush()
		self.__all_indices__=list(range(len(self.classes)))
		self.__recompute_indices__()
		return self
	def shuffle(self):
		"""Shuffle the epoch order and the full index list in place; returns self."""
		if self.verbose:
			print(' *** Shuffle images')
			sys.stdout.flush()
		np.random.shuffle(self.__indices__)
		np.random.shuffle(self.__all_indices__)
		return self
	def preload(self):
		"""Load every image once, e.g. to populate ``image_cache`` before training.

		This method keeps no reference to the images itself; whether they stay in memory
		depends on ``image_cache`` and ``image_load_function``.  It does not check that a
		cache is set.

		:returns: self, for chaining
		:raises Exception: if an image's shape differs from ``self.shape``
		"""
		if self.verbose:
			print(' *** Preloading images')
			sys.stdout.flush()
		self._print_memory()
		length=len(self.filenames)
		for i,name in enumerate(self.filenames):
			if self.verbose:
				# 1-based, like the batch counter in __getitem__
				sys.stdout.write("image {}/{}         \r".format(i+1,length))
				if i==length-1:
					sys.stdout.write("\n")
				sys.stdout.flush()
			img=self.__load_image__(name)
			if img.shape!=self.shape:
				raise Exception('Error missmatching dimensions {} expected {}'.format(img.shape,self.shape))
		self._print_memory()
		return self
	def split(self,*splitpoints):
		"""Partition the samples into consecutive parts and return one generator per part.

		``splitpoints`` are fractions strictly between 0 and 1 in non-decreasing order; k points
		give k+1 generators.  Samples are taken in the current order of the full index list,
		so call ``shuffle()`` first for a random split.  A repeated point yields an empty part,
		which the constructor rejects while ``classnames`` is non-empty.  Every part reuses this
		generator's ``image_cache`` and other settings and keeps the full ``classnames``.

		:raises Exception: if a point lies outside (0, 1) or the points are not sorted
		"""
		if self.verbose:
			print(' *** Splitting images at splitpoints {}'.format(list(splitpoints)))
			sys.stdout.flush()
		for point in splitpoints:
			if point <= 0 or point >= 1:
				raise Exception('ERROR splitpoints must be between 0 and 1 non inclusive')
		if splitpoints!=tuple(sorted(splitpoints)):
			raise Exception('ERROR splitpoints not monotonically increasing')
		total=len(self.filenames)
		bounds=[int(math.ceil(total*point)) for point in (0,)+tuple(splitpoints)+(1,)]
		splits=[self.__all_indices__[start:end] for start,end in zip(bounds,bounds[1:])]
		return [ SplitSetImageGenerator(
				filenames=[self.filenames[i] for i in indices],
				classes=[self.classes[i] for i in indices],
				classnames=self.classnames,
				image_cache=self.image_cache,
				image_load_function=self.image_load_function,
				scale=self.scale,
				batch_size=self.batch_size,
				max_per_class_and_epoch=self.max_per_class_and_epoch,
				auto_shuffle=self.auto_shuffle,
				verbose=self.verbose ) for indices in splits]
	def get_filenames(self,indices):
		"""Return the filenames at the given positions of the current epoch order."""
		return [self.filenames[self.__indices__[i]] for i in indices]
	def __getitem__(self,index):
		"""Return batch ``index`` as float32 arrays ``(X, Y)``, or None if the batch is empty.

		X stacks the loaded images (``(n,) + image.shape``) and Y the one-hot labels
		(``(n,) + label.shape``), where n <= batch_size.
		"""
		indices=self.__indices__[index*self.batch_size:(index+1)*self.batch_size]
		if not indices:
			return None
		self.__i__+=1
		if self.verbose:
			sys.stdout.write("batch {}/{} ({})         \r".format(self.__i__,len(self),(self.__i__-1)*self.batch_size+len(indices)))
			if self.__i__==len(self):
				sys.stdout.write("\n")
			sys.stdout.flush()
		X=Y=None
		for row,sample in enumerate(indices):
			img=self.__load_image__(self.filenames[sample])
			label=to_categorical(self.classes[sample],num_classes=self.num_classes)
			if X is None:
				# allocate once the first image and label reveal the per-sample shapes
				X=np.empty((len(indices),)+img.shape,dtype='float32')
				Y=np.empty((len(indices),)+label.shape,dtype='float32')
			X[row,]=img
			Y[row,]=label
		return X,Y
	def __len__(self):
		# number of batches per epoch; the last one may hold fewer than batch_size samples
		return int(math.ceil(len(self.__indices__)/self.batch_size))
	def on_epoch_end(self):
		# Keras Sequence end-of-epoch hook: report memory, reset the batch progress counter
		# and, when auto_shuffle is on, redraw the per-class samples and reshuffle.
		self._print_memory()
		self.__i__=0
		if self.auto_shuffle:
			self.__recompute_indices__()
			self.shuffle()


