import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import uuid
from typing import Tuple

# Numeric precision passed to the pipeline as torch_dtype when loading the weights.
dtype = torch.bfloat16
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the FLUX.1-schnell pipeline once at import time and move it to the chosen device.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)

# Largest seed offered by the UI / drawn at random: the max signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): not referenced anywhere else in the visible file — confirm whether it is still needed.
MAX_IMAGE_SIZE = 2048

# Prompt styles selectable in the UI. Each entry has a display "name" and a
# "prompt" template; the "{prompt}" placeholder is filled with the user's text
# via str.format (see apply_style).
style_list = [
    {
        "name": "8K",
        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    },
    {
        "name": "4K",
        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    },
    {
        # NOTE(review): the "HD" style's template says "2K" — presumably intentional; confirm.
        "name": "HD",
        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    },
    {
        "name": "BW",
        "prompt": "black and white collage of {prompt}. monochromatic, timeless, classic, dramatic contrast",
    },
    {
        "name": "Polar",
        "prompt": "collage of polaroid photos featuring {prompt}. vintage style, high contrast, nostalgic, instant film aesthetic",
    },
    {
        "name": "Mustard",
        "prompt": "Duotone style Mustard applied to {prompt}",
    },
    {
        "name": "Cinema",
        "prompt": "cinematic collage of {prompt}. film stills, movie posters, dramatic lighting",
    },
    {
        "name": "Coral",
        "prompt": "Duotone style Coral applied to {prompt}",
    },
    {
        "name": "Scrap",
        "prompt": "scrapbook style collage of {prompt}. mixed media, hand-cut elements, textures, paper, stickers, doodles",
    },
    {
        "name": "Fuchsia",
        "prompt": "Duotone style Fuchsia tone applied to {prompt}",
    },
    {
        "name": "Violet",
        "prompt": "Duotone style Violet applied to {prompt}",
    },
    {
        "name": "Pastel",
        "prompt": "Duotone style Pastel applied to {prompt}",
    },
    {
        # Pass-through style: the template is just the placeholder, so the
        # user's prompt is used unchanged.
        "name": "Style Zero",
        "prompt": "{prompt}",
    },
]

# Custom CSS handed to gr.Blocks: horizontally centers the element with id
# "col-container" and caps its width at 530px.
css="""
#col-container {
    margin: 0 auto;
    max-width: 530px;
}
"""

# Look-up table from a style's display name to its prompt template.
styles = dict((entry["name"], entry["prompt"]) for entry in style_list)
# Style pre-selected in the UI and used as infer()'s default.
DEFAULT_STYLE_NAME = "Style Zero"
# Display names in declaration order; these become the radio-button choices.
STYLE_NAMES = [*styles]

def apply_style(style_name: str, positive: str) -> str:
    """Wrap the user's prompt in the template of the chosen style.

    Args:
        style_name: Display name of a style; looked up in ``styles``.
        positive: The user's prompt text.

    Returns:
        The style's template with ``{prompt}`` replaced by ``positive``, or
        ``positive`` unchanged when ``style_name`` is not a known style.
    """
    template = styles.get(style_name)
    if template is None:
        # Unknown style: leave the prompt exactly as the user wrote it.
        return positive
    return template.format(prompt=positive)

def set_wallpaper_size(size):
    """Translate a wallpaper-size label into pixel dimensions.

    Args:
        size: Label chosen in the UI, e.g. ``"Mobile (1080x1920)"``.

    Returns:
        A ``(width, height)`` tuple. Any label other than the four named
        presets below yields ``(1024, 1024)``.
    """
    presets = (
        ("Mobile (1080x1920)", (1080, 1920)),
        ("Desktop (1920x1080)", (1920, 1080)),
        ("Extented (1920x512)", (1920, 512)),
        ("Headers (1080x512)", (1080, 512)),
    )
    # Compare with == (not a dict lookup) so any value, hashable or not, is
    # accepted and simply falls through to the default.
    for label, dimensions in presets:
        if size == label:
            return dimensions
    return 1024, 1024

@spaces.GPU(duration=60, enable_queue=True)
def infer(prompt, seed=42, randomize_seed=False, wallpaper_size="Desktop(1920x1080)", num_inference_steps=4, style_name=DEFAULT_STYLE_NAME, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for ``prompt`` and save it as a PNG file.

    Args:
        prompt: The user's prompt text.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        wallpaper_size: Size label resolved by ``set_wallpaper_size``.
            NOTE(review): the default label has no space before "(" and so
            matches none of the named presets, which means the default
            resolves to 1024x1024 — confirm this is intended.
        num_inference_steps: Number of steps passed to the pipeline.
        style_name: Name of the prompt style applied via ``apply_style``.
        progress: Gradio progress tracker (configured to track tqdm).

    Returns:
        A tuple ``(path, seed)``: the file name of the saved PNG (written to
        the current working directory) and the integer seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Slider values and API clients may deliver numbers as floats, but
    # torch.Generator.manual_seed and the pipeline's step count need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator().manual_seed(seed)

    width, height = set_wallpaper_size(wallpaper_size)

    styled_prompt = apply_style(style_name, prompt)

    options = {
        "prompt": styled_prompt,
        "width": width,
        "height": height,
        "guidance_scale": 0.0,
        "num_inference_steps": num_inference_steps,
        "generator": generator,
    }

    # Release cached GPU memory before the (large) generation call.
    torch.cuda.empty_cache()
    images = pipe(**options).images

    # Paste the first generated image at the top-left of a canvas of exactly
    # the requested size, so the saved file always has those dimensions.
    # NOTE(review): if the pipeline returns a smaller image (e.g. it rounds
    # the size down), the uncovered area keeps the canvas default colour — confirm.
    grid_img = Image.new('RGB', (width, height))
    grid_img.paste(images[0], (0, 0))

    # A random UUID keeps concurrent requests from overwriting each other.
    unique_name = str(uuid.uuid4()) + ".png"
    grid_img.save(unique_name)
    return unique_name, seed

# Example prompts shown in the UI through gr.Examples; each string is used
# as-is as the value of the prompt input.
examples = [

    "chocolate dripping from a donut a yellow background",
    "cold coffee in a cup bokeh --ar 85:128 --style",
    "an anime illustration of a wiener schnitzel",
    "a delicious ceviche cheesecake slice, ultra-hd+",  
]

def load_predefined_images1():
    """Return the paths of the sample images displayed in the gallery.

    Returns:
        A new list of three relative ``.webp`` paths under ``assets/``.
    """
    return [
        "assets/ww.webp",
        "assets/xx.webp",
        "assets/yy.webp",
    ]

# Build the Gradio page. Components are laid out in the order they are created
# inside the nested context managers, so statement order here is significant.
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    
    # Centered, width-limited column (styled by the #col-container CSS rule)
    # holding the title, the prompt box with its Run button, and the result.
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 SIM""")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)

    # Generation options: output size, prompt style, and advanced settings.
    with gr.Row(visible=True):
        # These labels are the strings set_wallpaper_size compares against;
        # "Default (1024x1024)" is handled by its fall-through branch.
        wallpaper_size = gr.Radio(
            choices=["Mobile (1080x1920)", "Desktop (1920x1080)", "Extented (1920x512)", "Headers (1080x512)", "Default (1024x1024)"],
            label="Pixel Size(x*y)",
            value="Default (1024x1024)"
        )

        with gr.Row(visible=True):
            # One radio choice per entry in style_list.
            style_selection = gr.Radio(
                show_label=True,
                container=True,
                interactive=True,
                choices=STYLE_NAMES,
                value=DEFAULT_STYLE_NAME,
                label="Quality Style",
            )
        with gr.Accordion("Advanced Settings", open=True):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            # Checked by default, so the seed slider value is ignored unless
            # the user unchecks this box.
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4,
                )
                
        # Example prompts wired to the prompt input only.
        # NOTE(review): with cache_examples=False and no run_on_click, Gradio
        # presumably only fills the prompt box and does not call infer — confirm.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples=False,
        )

    # Run infer when the prompt is submitted or the Run button is clicked;
    # the returned image goes to `result` and the seed actually used is
    # written back to the seed slider.
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, wallpaper_size, num_inference_steps, style_selection],
        outputs=[result, seed]
    )
    
    # Static gallery of sample images from the assets/ directory.
    gr.Markdown("### Image Sample")
    predefined_gallery = gr.Gallery(label="## Images Sample", columns=3, show_label=False, value=load_predefined_images1())

    # Footer: model attribution and usage disclaimer.
    gr.Markdown("**Disclaimer/Note:**")
    
    gr.Markdown("🍕Model used in the space <a href='https://huggingface.co/black-forest-labs/FLUX.1-schnell' target='_blank'>black-forest-labs/FLUX.1-schnell</a>. More: 12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]")
    gr.Markdown("⚠️ users are accountable for the content they generate and are responsible for ensuring it meets appropriate ethical standards.")

# Start the Gradio server; this runs at import time because it is not guarded
# by an `if __name__ == "__main__":` check.
demo.launch()