from PIL import Image, ImageDraw, ImageChops, ImageFont
import aggdraw

import sys
import heapq
import random
from operator import itemgetter


def generate_ga(prob_mutate,
                elitism,
                popsize,
                generation_count,
                generators,
                create,
                fitness,
                express,
                save,
                load):
    """Build a genetic-algorithm driver and return it as a callable.

    A specimen is a list of genes; each gene is a dict holding one value per
    key of ``generators``.  Lower fitness scores are better.

    Args:
        prob_mutate: Per-gene mutation probability.  A child is eligible for
            mutation at all only with probability ``3 * prob_mutate``.
        elitism: Fraction of ``popsize`` kept as elites each generation.
        popsize: Number of specimens in every generation.
        generation_count: Generation index at which the run stops.
        generators: Mapping of characteristic name to a zero-argument
            callable that produces a fresh value for that characteristic.
        create: Zero-argument callable returning a new specimen.
        fitness: Callable scoring the result of ``express`` (lower wins).
        express: Callable converting a specimen into what ``fitness`` scores.
        save: Callable invoked once per generation as
            ``save(generation, winners, winner, score)``.
        load: Zero-argument callable returning ``(generation, winners)`` to
            resume from, or a false value to start from scratch.

    Returns:
        A zero-argument function ``run``.  Calling it executes the
        generations and returns the first winner of the last generation
        evaluated (the best-scoring specimen whenever at least one elite is
        kept), or ``[]`` if no generation was evaluated.
    """
    elites = int(elitism * popsize)
    # A real list, so random.choice can index it on Python 2 and 3 alike.
    criteria = list(generators)

    def mutate(specimen):
        """Re-roll one randomly chosen characteristic of a gene in place."""
        characteristic = random.choice(criteria)
        specimen[characteristic] = generators[characteristic]()
        return specimen

    def crossover(set1, set2):
        """Breed a child, taking each gene from either parent with equal odds."""
        mutate_it = random.random() < prob_mutate * 3
        newset = []
        # zip stops at the shorter parent, which fixes the child's length.
        for gene1, gene2 in zip(set1, set2):
            # Copy so that mutating the child never alters a parent.
            gene = (gene1 if random.random() < 0.5 else gene2).copy()
            if mutate_it and random.random() < prob_mutate:
                gene = mutate(gene)
            newset.append(gene)
        return newset

    def repopulate(elitists):
        """Build the next generation (``popsize`` specimens) from survivors."""
        # The best survivor is carried over untouched.
        ret = {0: elitists[0]}
        others = elitists[1:]
        # Start at 1: starting at 2 would leave the population one short.
        for k in range(1, popsize):
            # Half of the children (all of them when there is a single
            # survivor) get the best specimen as their first parent.
            if not others or random.random() < 0.5:
                prime = elitists[0]
            else:
                prime = random.choice(others)
            ret[k] = crossover(prime, random.choice(elitists))
        return ret

    def get_scores(pop):
        """Map every population key to the fitness of its specimen."""
        return dict((key, fitness(express(spec))) for key, spec in pop.items())

    def run():
        """Run the generations and return the final winner."""
        winner = []
        generation = 0
        load_result = load()
        if load_result:
            # Resume: breed a fresh population from the saved winners.
            generation, winners = load_result
            population = repopulate(winners)
        else:
            population = dict((i, create()) for i in range(popsize))

        for i in range(generation, generation_count):
            print('generation #%d' % (i))
            scores = get_scores(population)
            # Rank once, best first, then split so that elites and losers
            # can never contain the same specimen (even when scores tie).
            ranked = sorted(scores.items(), key=itemgetter(1))
            score = ranked[0][1]
            print('  score: %d' % (score))
            selection = ranked[:elites]
            unselected = ranked[elites:]
            print(selection)
            winners = [population[key] for key, _ in selection]
            losers = [population[key] for key, _ in unselected]
            # A couple of random losers keep some diversity in the gene pool.
            winners = winners + random.sample(losers, min(2, len(losers)))
            winner = winners[0]
            save(i, winners, winner, score)
            population = repopulate(winners)
        return winner

    return run



def circle_image(num_circles, image_file):
    """Build a GA runner that approximates an image with translucent circles.

    Each specimen is a list of ``num_circles`` circles; a circle is a dict
    with ``position`` (x, y), ``color`` (four 0-255 values, the fourth used
    as brush opacity) and ``radius``.  Fitness is the summed per-channel
    difference between the rendered circles and the target image.

    Side effects while running: writes ``./output/pass_<n>.jpg`` every 100
    generations, ``../crumweb/current_ga_results.jpg`` every 5 and
    ``./output/state.py`` (resume state) every 20.

    Args:
        num_circles: Number of circles in every specimen.
        image_file: Path of the target image.

    Returns:
        The zero-argument runner produced by ``generate_ga``.
    """
    def express(circles):
        """Render a specimen onto a new image the size of the target."""
        image = Image.new('RGB', im.size)
        draw = aggdraw.Draw(image)
        for circle in circles:
            rad = circle['radius']
            pos = circle['position']
            brush = aggdraw.Brush(circle['color'], opacity=circle['color'][3])
            bounding_box = (pos[0] - rad, pos[1] - rad, pos[0] + rad, pos[1] + rad)
            draw.ellipse(bounding_box, brush)
        draw.flush()
        del draw
        return image

    def load():
        """Return ``(pass_number, winners)`` from saved state, else False."""
        sys.path.append('./output')

        # NOTE: the state is restored by importing ./output/state.py, which
        # executes that file; only run against state this script wrote.
        try:
            import state
        except ImportError:
            # No saved state yet: tell the caller to start from scratch.
            return False
        return (state.pass_number, state.winners)

    fnt = ImageFont.truetype('Verdana.ttf', 12)

    def save(generation, population, winner, winning_score):
        """Persist snapshots of the winner and, periodically, resume state."""
        if generation % 100 == 0:
            express(winner).save('./output/pass_' + str(generation) + '.jpg')

        if generation % 5 == 0:
            rdr = express(winner)
            dr = ImageDraw.Draw(rdr)
            dr.text((0, 0), 'generation: %d    score: %d' % (generation, winning_score), font=fnt, fill=(0, 0, 0))
            del dr
            rdr.save('../crumweb/current_ga_results.jpg')

        if generation % 20 == 0:
            # Written as Python source so that load() can import it back.
            with open('./output/state.py', 'w') as fp:
                fp.write('pass_number = ' + str(generation) + '\n')
                fp.write('winners = ' + str(population) + '\n')

    # Candidates are rendered in RGB; the target must share that mode for
    # ImageChops.difference to accept the pair.
    im = Image.open(image_file).convert('RGB')

    def fitness(candidate):
        """Sum every channel of the pixel-wise difference to the target."""
        return sum(sum(pix) for pix in ImageChops.difference(im, candidate).getdata())

    random.seed()

    generators = {
        'position': lambda: (random.randint(0, im.size[0]), random.randint(0, im.size[1])),
        'color': lambda: tuple([random.randint(0, 255) for k in range(0, 4)]),
        # Floor division keeps the bound an int; never let it drop below 1.
        'radius': lambda: random.randint(1, max(1, im.size[0] // 2)),
    }

    def generate():
        """Create one random circle with a value from every generator."""
        return dict((name, make()) for name, make in generators.items())

    def create():
        """Create a random specimen of ``num_circles`` circles."""
        return [generate() for k in range(num_circles)]

    return generate_ga(prob_mutate=0.25,
                       elitism=0.40,
                       popsize=24,
                       generation_count=1000000,
                       generators=generators,
                       create=create,
                       fitness=fitness,
                       express=express,
                       save=save,
                       load=load)

# Script entry point: build the runner for the target photo, then call it
# immediately to start (or resume) the evolution.
circle_image(num_circles=140, image_file='../crumweb/ryan.jpg')()
