import socket
import time
import psutil
import re
import os
import jinja2
from humanize_time import humanize_time
from utils import IP2NAME, BASE_DIR, SNAPSHOT_INTERVAL, TEMPLATE_DIR
from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetUtilizationRates, nvmlDeviceGetMemoryInfo, nvmlDeviceGetName, nvmlDeviceGetComputeRunningProcesses

def get_host_ip():
    """Return the local IPv4 address used for outbound traffic.

    A UDP ``connect()`` transmits nothing; it only makes the kernel choose a
    route (and hence a source address) towards 8.8.8.8, which is then read
    back with ``getsockname()``.

    Returns:
        The local IPv4 address as a dotted-quad string.

    Raises:
        OSError: if the socket cannot be created or no route exists.
    """
    # The context manager closes the socket on every path. It also avoids the
    # UnboundLocalError the old try/finally raised when socket() itself failed.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]

def get_gpu_proc_info(p):
    """Describe one process that NVML reports as running on a GPU.

    Args:
        p: A record from ``nvmlDeviceGetComputeRunningProcesses``. Only its
           ``pid`` and ``usedGpuMemory`` attributes are read.

    Returns:
        A dict with keys ``pid``, ``user``, ``cmd``, ``runtime`` and
        ``mem_usage`` (GPU memory in MB, via ``toMB``). Fields that cannot be
        looked up are reported as ``'<unknown>'`` instead of raising. This
        keeps a single vanished or inaccessible process from turning the whole
        GPU entry into an error.
    """
    pid = p.pid

    # NVML reports None when per-process memory accounting is unavailable.
    used_bytes = p.usedGpuMemory
    mem_usage = toMB(used_bytes) if used_bytes is not None else 0

    cmd = user = runtime = '<unknown>'
    try:
        proc = psutil.Process(pid)
        cmd = proc.name() or '<unknown>'
        user = proc.username() or '<unknown>'
        runtime = humanize_time(time.time() - proc.create_time())
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        # Deliberate best-effort: the process may have exited since NVML listed
        # it, may belong to another PID namespace (e.g. a container), or may not
        # be readable by this user. Keep whatever fields were already filled in.
        pass

    return {
        "pid": pid,
        "user": user,
        "cmd": cmd,
        "runtime": runtime,
        "mem_usage": mem_usage
    }

def toGB(x):
    """Convert a byte count to whole GiB (2**30 bytes).

    Uses built-in ``round``, so exact halves round to the nearest even integer.
    """
    return round(x / (2 ** 30))


def toMB(x):
    """Convert a byte count to whole MiB (2**20 bytes).

    Uses built-in ``round``, so exact halves round to the nearest even integer.
    """
    return round(x / (2 ** 20))

def _collect_sys_data():
    """Snapshot CPU, memory, swap and large /dev/sd* disk usage via psutil."""
    freq = psutil.cpu_freq()
    # cpu_freq() may return None, and ``max`` may be 0.0 when the platform does
    # not expose it; fall back to the current frequency instead of showing 0.
    if freq is None:
        freq_ghz = 0.0
    else:
        freq_ghz = (freq.max or freq.current) / 1000  # MHz -> GHz

    # Read each counter once so total and percent come from the same sample.
    vmem = psutil.virtual_memory()
    swap = psutil.swap_memory()

    sys_data = {
        "cpu": {
            "cores": psutil.cpu_count(logical=False),
            "threads": psutil.cpu_count(logical=True),
            "freq": freq_ghz,
            "percent": psutil.cpu_percent()
        },
        "mem": {
            "total": toGB(vmem.total),
            "percent": vmem.percent
        },
        "swap": {
            "total": toGB(swap.total),
            "percent": swap.percent
        },
        "disk": []
    }

    for partition in psutil.disk_partitions():
        # Only /dev/sd* devices are considered (NVMe/virtio names do not match).
        if not re.match(r'^/dev/sd[a-z]', partition.device):
            continue

        try:
            usage = psutil.disk_usage(partition.mountpoint)
        except OSError:
            # A stale or permission-restricted mount must not kill the daemon.
            continue

        total_gb = toGB(usage.total)
        if total_gb < 100:  # skip volumes smaller than 100 GB
            continue

        sys_data["disk"].append({
            "name": partition.mountpoint,
            "total": total_gb,
            "used": toGB(usage.used),
            "free": toGB(usage.free),
            "percent": usage.percent
        })

    return sys_data


def _collect_gpu_data():
    """Return one dict per NVML-visible GPU. Requires nvmlInit() to have run."""
    gpu_data = []
    for i in range(nvmlDeviceGetCount()):
        try:
            handle = nvmlDeviceGetHandleByIndex(i)
            name = nvmlDeviceGetName(handle)
            if isinstance(name, bytes):  # older pynvml versions return bytes
                name = name.decode('utf-8')

            mem_info = nvmlDeviceGetMemoryInfo(handle)
            utilization = nvmlDeviceGetUtilizationRates(handle)
            procs = nvmlDeviceGetComputeRunningProcesses(handle)

            gpu_data.append({
                "index": i,
                "name": name,
                "mem_total": toMB(mem_info.total),
                "mem_used": toMB(mem_info.used),
                "mem_free": toMB(mem_info.free),
                "mem_percent": mem_info.used / mem_info.total * 100,
                "utilization": utilization.gpu,
                "procs": [get_gpu_proc_info(p) for p in procs]
            })
        except Exception as e:
            # Per-GPU boundary: show the error in place of the GPU name. Use
            # the same keys as a healthy entry so the template renders it.
            gpu_data.append({
                "index": i,
                "name": str(e),
                "mem_total": 0,
                "mem_used": 0,
                "mem_free": 0,
                "mem_percent": 0,
                "utilization": 0,
                "procs": []
            })
    return gpu_data


def _write_atomic(path, text):
    """Write *text* to *path* through a temp file and ``os.replace``.

    On POSIX the rename is atomic, so a concurrent reader sees either the old
    or the new page, never a truncated one.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp_path, path)


def generate_html():
    """Periodically render this machine's resource usage to ``stat.html``.

    Picks ``section-gpu.html`` or ``section-cpu.html`` from TEMPLATE_DIR.
    Machines whose IP2NAME entry contains "CPU" get the CPU-only page. The
    function then loops forever: it gathers a snapshot, renders it to
    ``TEMPLATE_DIR/stat.html``, logs a line, and sleeps SNAPSHOT_INTERVAL
    seconds.
    """
    machine_ip = get_host_ip()
    machine_name = IP2NAME[machine_ip] + '(' + machine_ip + ')'
    is_gpu_server = "CPU" not in machine_name

    template_name = "section-gpu.html" if is_gpu_server else "section-cpu.html"
    with open(os.path.join(TEMPLATE_DIR, template_name), encoding='utf-8') as f:
        # Autoescape: process names, user names and NVML error text are
        # arbitrary strings. Unescaped, a literal '<unknown>' would be parsed
        # as an HTML tag and vanish, and a crafted process name could inject
        # markup into the page.
        template = jinja2.Template(f.read(), autoescape=True)

    output_path = os.path.join(TEMPLATE_DIR, "stat.html")

    if is_gpu_server:
        # Initialise NVML once rather than on every snapshot.
        nvmlInit()

    while True:
        sys_data = _collect_sys_data()

        render_args = {
            "machine_name": machine_name,
            "update_time": time.strftime("%m-%d-%Y %H:%M:%S"),
            "sys_data": sys_data
        }
        if is_gpu_server:
            render_args["gpu_data"] = _collect_gpu_data()

        _write_atomic(output_path, template.render(**render_args))

        print(f"[{time.strftime('%H:%M:%S')}] Generated HTML for {machine_name}")

        time.sleep(SNAPSHOT_INTERVAL)

if __name__ == "__main__":
    generate_html()