#!/usr/bin/env python

# Generates (or updates) a unit test image mask, which is used to specify whether
# a pixel in the control image should be checked (black pixel in mask) or not (white
# pixel in mask). For pixels that are neither black nor white, the pixel's lightness
# is used to specify a maximum delta for each color component.

import os
import sys
import argparse
from PyQt4.QtCore import *
from PyQt4.QtGui import *
import struct
import urllib2
import glob

def error ( msg ):
  """Report a fatal problem and terminate the script.

  msg is written, followed by a newline, to stderr so that diagnostics do
  not get mixed into the progress output printed on stdout. The process
  then exits with status 1 (sys.exit raises SystemExit), so this function
  never returns to its caller.
  """
  sys.stderr.write( '{}\n'.format( msg ) )
  sys.exit( 1 )

def colorDiff( c1, c2 ):
  """Return the largest per-channel difference between two colors.

  The red, green, blue and alpha components of c1 and c2 are extracted with
  qRed/qGreen/qBlue/qAlpha; the result is the maximum of the four absolute
  component differences.
  """
  channels = ( qRed, qGreen, qBlue, qAlpha )
  return max( abs( channel( c1 ) - channel( c2 ) ) for channel in channels )

def imageFromPath(path):
  """Load a QImage from a local file name or from a URL.

  Paths starting with 'http://', 'https://' or 'file://' are fetched with
  urllib2 and the downloaded bytes are decoded with QImage.loadFromData;
  any other path is passed to the QImage constructor as a file name.

  A QImage is always returned. This function does not itself check whether
  loading succeeded, so callers should test the result with isNull().
  Exceptions raised by urllib2.urlopen are not caught here.
  """
  if path.startswith( ( 'http://', 'https://', 'file://' ) ):
    #fetch remote image, making sure the connection is released afterwards
    response = urllib2.urlopen(path)
    try:
      data = response.read()
    finally:
      response.close()
    image = QImage()
    image.loadFromData(data)
  else:
    image = QImage( path )
  return image

def getControlImagePath(path):
  """Resolve path to the file name of a control image.

  If path is an existing file it is returned unchanged. Otherwise the
  '../tests/testdata/control_images' tree (relative to the folder containing
  the running script, taken from sys.argv[0]) is walked and every folder
  whose path string contains path is collected. Exactly one folder must
  match, and it must contain exactly one '*.png' file whose name does not
  end in '_mask.png'; that file's path is printed and returned.

  Any other outcome is reported through error(), which ends the script.
  """
  if os.path.isfile(path):
    return path

  #not a file, so look for a control image folder matching the given name
  script_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
  search_root = os.path.join( script_dir, '../tests/testdata/control_images')

  candidate_folders = []
  for folder, _subfolders, _files in os.walk(search_root):
    if path in folder:
      candidate_folders.append(folder)

  if len(candidate_folders) > 1:
    error('Found multiple matching control images for {}'.format(path))
  elif len(candidate_folders) == 0:
    error('No matching control images found for {}'.format(path))

  control_folder = candidate_folders[0]

  #the folder must hold exactly one expected image once mask images are excluded
  expected_images = []
  for png_path in glob.glob( os.path.join(control_folder, '*.png') ):
    if png_path.endswith('_mask.png'):
      continue
    expected_images.append(png_path)

  if len(expected_images) > 1:
    error('Found multiple matching control images for {}'.format(path))
  elif len(expected_images) == 0:
    error('No matching control images found for {}'.format(path))

  control_image = expected_images[0]
  print('Found matching control image: {}'.format(control_image))
  return control_image

def updateMask(control_image_path, rendered_image_path, mask_image_path):
  """Relax the mask image wherever the rendered image differs from the control image.

  control_image_path  -- path or URL of the expected (control) image
  rendered_image_path -- path or URL of the image actually rendered
  mask_image_path     -- path of the mask image; if it cannot be loaded a new
                         all-black mask with the control image's size is used

  For every pixel in the area common to the control and rendered images the
  red component of the mask pixel is taken as the currently allowed
  difference. A value of 255 means the pixel is ignored. When the largest
  per-channel difference (see colorDiff) exceeds the allowed value, the mask
  pixel is set to a grey level equal to that difference.

  The mask is written back to mask_image_path as PNG only if at least one
  pixel was changed. Unreadable control/rendered images, a mask smaller than
  the compared area and a failed save are reported through error(), which
  ends the script. A size mismatch between control and rendered image only
  prints a warning.
  """
  control_image = imageFromPath( control_image_path )
  # imageFromPath always returns a QImage object, so a failed load has to be
  # detected with isNull() rather than by testing the object's truthiness
  if control_image.isNull():
    error('Could not read control image {}'.format(control_image_path))

  rendered_image = imageFromPath( rendered_image_path )
  if rendered_image.isNull():
    error('Could not read rendered image {}'.format(rendered_image_path))
  if rendered_image.width() != control_image.width() or rendered_image.height() != control_image.height():
    print ('Size mismatch - control image is {}x{}, rendered image is {}x{}'.format(control_image.width(),
                                                                                   control_image.height(),
                                                                                   rendered_image.width(),
                                                                                   rendered_image.height()))

  #only the area present in both images can be compared
  max_width = min( rendered_image.width(), control_image.width() )
  max_height = min( rendered_image.height(), control_image.height() )

  #read current mask, if it exists
  mask_image = imageFromPath( mask_image_path )
  if mask_image.isNull():
    print('Mask image does not exist, creating {}'.format( mask_image_path ))
    mask_image = QImage( control_image.width(), control_image.height(), QImage.Format_ARGB32 )
    mask_image.fill( QColor(0,0,0) )
  elif mask_image.width() < max_width or mask_image.height() < max_height:
    #reading scanlines beyond the mask's extent would access invalid memory
    error('Mask image {} is {}x{}, smaller than the {}x{} area being compared'.format(mask_image_path,
                                                                                     mask_image.width(),
                                                                                     mask_image.height(),
                                                                                     max_width,
                                                                                     max_height))

  #the scanline parsing below assumes 4 bytes per pixel, so normalise all
  #images to a 32 bit format first (e.g. indexed or greyscale PNGs are not)
  control_image = control_image.convertToFormat( QImage.Format_ARGB32 )
  rendered_image = rendered_image.convertToFormat( QImage.Format_ARGB32 )
  mask_image = mask_image.convertToFormat( QImage.Format_ARGB32 )

  #loop through pixels in rendered image and compare
  mismatch_count = 0
  linebytes = max_width * 4
  #one native unsigned int per pixel, unpacked a whole scanline at a time
  line_format = '{}I'.format( max_width )
  for y in xrange( max_height ):
    control_pixels = struct.unpack( line_format, control_image.constScanLine( y ).asstring(linebytes) )
    rendered_pixels = struct.unpack( line_format, rendered_image.constScanLine( y ).asstring(linebytes) )
    #the mask values are a snapshot taken before any pixel of this line is modified
    mask_pixels = struct.unpack( line_format, mask_image.constScanLine( y ).asstring(linebytes) )

    for x in xrange( max_width ):
      current_tolerance = qRed( mask_pixels[x] )

      if current_tolerance == 255:
        #ignore pixel
        continue

      difference = colorDiff( control_pixels[x], rendered_pixels[x] )

      if difference > current_tolerance:
        #update mask image
        mask_image.setPixel( x, y, qRgb( difference, difference, difference ) )
        mismatch_count += 1

  if mismatch_count:
    #update mask
    if not mask_image.save( mask_image_path, "png" ):
      error('Could not write mask image {}'.format( mask_image_path ))
    print('Updated {} pixels in {}'.format( mismatch_count, mask_image_path ))
  else:
    print('No mismatches in {}'.format( mask_image_path ))

# Command line handling: control_image rendered_image [mask_image]
parser = argparse.ArgumentParser(description='Generate or update the mask image for a unit test control image.')
parser.add_argument('control_image',
                    help='control image file, or a name matching a folder below tests/testdata/control_images')
parser.add_argument('rendered_image',
                    help='path or URL of the rendered image to compare against the control image')
parser.add_argument('mask_image', nargs='?', default=None,
                    help='mask image to create or update (default: <control image>_mask.png)')
args = parser.parse_args()

args.control_image = getControlImagePath( args.control_image)

if not args.mask_image:
  # strip the real extension, whatever its length, instead of assuming 4 characters
  args.mask_image = os.path.splitext( args.control_image )[0] + '_mask.png'

updateMask(args.control_image, args.rendered_image, args.mask_image)