import os
import glob
import re
from PIL import Image
from pathlib import Path
import shutil

# If True, an image that is not referenced in any markdown file is deleted.
# If False, such an image makes the script stop with an exception instead.
delete_unused_images = True
# Maximum image resolution, in dots per inch.
# You can set a very big value here if you don't want to resize images.
max_density = 400

full_width = 800 # width of the page, in pixels
page_width = 190 # in millimeters, without print margins

# Anchor every path on the absolute location of this script: with a bare relative
# __file__ (e.g. "script.py"), dirname() returns '' and '' + '/../source' would
# wrongly designate a folder at the root of the filesystem.
script_dir = os.path.dirname(os.path.abspath(__file__))

source_dir = script_dir + '/../source'
images_dir = script_dir + '/../source/img' # subdirectories are explored too
target_dir = script_dir + '/../source/img' # can be the same directory as images_dir, or another one

# Every markdown file below source_dir (searched recursively)
md_sources = glob.glob(source_dir + '/**/*.md', recursive=True)

# Replace an image in all source files
def replace_image(original_rel_path, new_rel_path, sources=None):
    """Replace every occurrence of an image path in the markdown source files.

    This is a plain substring replacement: each occurrence of
    `original_rel_path` in a file is replaced with `new_rel_path`. A file is
    only rewritten when its contents actually change.

    Args:
        original_rel_path: path of the image as currently written in the sources.
        new_rel_path: path to write instead.
        sources: iterable of file names to process. Defaults to the
            module-level `md_sources` list.
    """
    if sources is None:
        sources = md_sources

    for src_name in sources:
        # An explicit encoding avoids depending on the platform default, and
        # newline='' disables newline translation so that a rewritten file
        # keeps its original line endings.
        with open(src_name, encoding='utf-8', newline='') as src_file:
            original_contents = src_file.read()

        new_contents = original_contents.replace(original_rel_path, new_rel_path)
        if new_contents != original_contents:
            with open(src_name, 'w', encoding='utf-8', newline='') as new_file:
                new_file.write(new_contents)

# Patterns extracting the requested size from an image reference.
# An optional trailing "px" is left out of the captured group.
_HTML_WIDTH = r'width="(.*?)[px]*"'
_HTML_HEIGHT = r'height="(.*?)[px]*"'
_DIRECTIVE_WIDTH = r':width:\s*(.*?)[px]*\s'
_DIRECTIVE_HEIGHT = r':height:\s*(.*?)[px]*\s'


def _to_pixels(text):
    """Return `text` as an integer, or None when it is not one (e.g. '50%' or '')."""
    try:
        return int(text)
    except ValueError:
        return None


def _reference_width(code, width_pattern, height_pattern, aspect):
    """Return the display width, in pixels, requested by one image reference.

    An explicit width wins; otherwise the width is derived from an explicit
    height and the aspect ratio. Without any usable size, the image is assumed
    to span the full page width, which errs on the side of keeping resolution.
    """
    width_code = re.search(width_pattern, code)
    if width_code is not None:
        width = _to_pixels(width_code.group(1))
        return full_width if width is None else width

    height_code = re.search(height_pattern, code)
    if height_code is not None:
        height = _to_pixels(height_code.group(1))
        if height is not None:
            return int(aspect * height + 0.5)

    return full_width


def _max_display_width(image_ref, aspect):
    """Return the largest width, in pixels, at which an image is referenced.

    All files of `md_sources` are scanned. The result is 0 when no reference
    to `image_ref` is found.
    """
    image_search = re.escape(image_ref)
    # ![alt text](image/path)
    markdown_re = re.compile(r'!\[.*\]\(' + image_search + r'\)')
    # <img src="image/path" width="w" height="h">
    html_re = re.compile(r'<img.*?src="' + image_search + r'".*?>')
    # ```{image} img/vhelio.png :width: wpx :height: hpx```
    directive_re = re.compile(r'```{image} ' + image_search + r'.*?```', re.MULTILINE | re.DOTALL)

    widest = 0
    for src_name in md_sources:
        with open(src_name, encoding='utf-8') as src_file:
            src_contents = src_file.read()

        # A plain markdown image carries no size: it uses the full page width
        if markdown_re.search(src_contents) is not None:
            widest = max(widest, full_width)

        for img_code in html_re.finditer(src_contents):
            widest = max(widest, _reference_width(img_code.group(0), _HTML_WIDTH, _HTML_HEIGHT, aspect))

        for img_code in directive_re.finditer(src_contents):
            widest = max(widest, _reference_width(img_code.group(0), _DIRECTIVE_WIDTH, _DIRECTIVE_HEIGHT, aspect))

    return widest


def _is_worth_keeping(new_size, original_size):
    """Tell whether a recompressed file is at least 100 kB or 20% smaller than the original."""
    return new_size < original_size - 100*1024 or new_size < original_size * 8/10


def _keep_original(source_path, target_path):
    """Copy the untouched original to the target folder, when it differs from the images folder."""
    if images_dir != target_dir:
        shutil.copyfile(source_path, target_path)


total_saved_space = 0

# The folder is listed once before processing: files are created, replaced and
# removed in it below, which must not interfere with the enumeration.
for image_path in list(Path(images_dir).glob("**/*")):
    suffix = image_path.suffix.lower()
    if suffix not in {".jpg", ".jpeg", ".png", ".svg"}: continue

    source_path = image_path.resolve()

    # compute target path relatively to the source folder
    image_rel_path = os.path.relpath(source_path, images_dir)
    image_rel_path = os.path.relpath(target_dir + '/' + image_rel_path, source_dir)
    target_path = source_dir + '/' + image_rel_path
    # Documents are searched with forward slashes, whatever the separator used by the OS
    image_ref = image_rel_path.replace(os.sep, '/')

    os.makedirs(os.path.dirname(target_path), exist_ok = True)

    # With a separate target folder, an image that already exists there is skipped
    if images_dir != target_dir and os.path.isfile(target_path): continue

    if suffix == '.svg':
        image_size = None
        image_aspect = 1
    else:
        # Only the dimensions are needed for now. The file is closed right away
        # so that no handle is left open if the image gets deleted below.
        with Image.open(source_path) as probe:
            image_size = probe.size
        image_aspect = image_size[0] / image_size[1]

    image_display_width = _max_display_width(image_ref, image_aspect)

    if image_display_width == 0:
        if delete_unused_images:
            print('WARNING: removing unused image ' + image_rel_path)
            os.remove(source_path)
            continue
        else:
            raise Exception('Image not found in source documents: ' + image_rel_path)

    if image_size is None:
        # SVG images are never resampled
        _keep_original(source_path, target_path)
        continue

    original_size = os.path.getsize(source_path)

    # Width covered on the printed page, then number of pixels needed at max_density
    image_width_inches = image_display_width / full_width * page_width / 25.4
    target_resolution_width = max(1, int(max_density * image_width_inches + 0.5))

    with Image.open(source_path) as image:
        if target_resolution_width >= image_size[0]:
            # Never upscale. copy() detaches the pixel data from the file being closed.
            resized = image.copy()
        else:
            target_resolution_height = max(1, int(target_resolution_width/image_size[0]*image_size[1]+0.5))
            resized = image.resize((target_resolution_width,target_resolution_height), Image.Resampling.LANCZOS)

    if suffix == '.png':
        # Try to save the file as JPEG to see if it would be significantly smaller
        # This helps detecting files that should be JPEG, not PNG

        # NOTE(review): the flattened RGB image is also the one written back as PNG,
        # so a recompressed PNG loses its transparency - confirm this is intended.
        if resized.mode != 'RGB':
            background = Image.new('RGBA', resized.size, (255,255,255))
            alpha_composite = Image.alpha_composite(background, resized.convert('RGBA'))
            resized = alpha_composite.convert('RGB')

        jpeg_target = target_path[0:-4] + '.jpg'
        jpeg_tmp = jpeg_target + '.tmp'
        png_tmp = target_path + '.tmp'

        resized.save(jpeg_tmp, format = 'JPEG', quality = 80)
        resized.save(png_tmp, format = 'PNG')

        # Force JPEG compression if it makes the image at least twice as small (in some cases, PNG can even give a smaller file)
        png_size = os.path.getsize(png_tmp)
        jpeg_size = os.path.getsize(jpeg_tmp)
        best_png_size = min(original_size, png_size)
        if jpeg_size < best_png_size - 200*1024 or jpeg_size < best_png_size / 2:
            os.remove(png_tmp)
            # The PNG only exists at the target location when both folders are the same.
            # NOTE(review): with a separate target folder the PNG stays in images_dir while
            # the documents now reference the JPEG - verify how a later run treats it.
            if os.path.exists(target_path):
                os.remove(target_path)
            os.replace(jpeg_tmp, jpeg_target)
            print('WARNING: ' + image_rel_path + ' has been converted to JPEG format')

            replace_image(image_ref, os.path.relpath(jpeg_target, source_dir).replace(os.sep, '/'))
            total_saved_space += original_size - jpeg_size
        else:
            os.remove(jpeg_tmp)
            if _is_worth_keeping(png_size, original_size):
                total_saved_space += original_size - png_size
                # os.replace overwrites an existing target and also works when there is none
                os.replace(png_tmp, target_path)
                print('Recompressed PNG ' + image_rel_path)
            else:
                os.remove(png_tmp)
                _keep_original(source_path, target_path)
    else:
        tmp_path = target_path + '.tmp'
        resized.save(tmp_path, format = 'JPEG', quality = 80)
        tmp_size = os.path.getsize(tmp_path)
        if _is_worth_keeping(tmp_size, original_size):
            total_saved_space += original_size - tmp_size
            os.replace(tmp_path, target_path)
            print('Recompressed JPEG ' + image_rel_path)
        else:
            os.remove(tmp_path)
            _keep_original(source_path, target_path)

print('Done. Saved ' + str(int(total_saved_space/1024+0.5)) + 'kB.')