import cv2 as cv
import numpy as np
from centroidtracker import CentroidTracker
from trackableobject import TrackableObject
from flask import Flask, Response, render_template_string
import threading
import time

# Detection parameters
confThreshold = 0.1  # minimum class score for a detection to be kept
nmsThreshold = 0.4   # NMS threshold handed to cv.dnn.NMSBoxes
inpWidth, inpHeight = 320, 320  # size the frame is resized to for the network

# Class names: one name per line, line index == class id
classesFile = "okok.names"
with open(classesFile, 'rt') as names_file:
    raw_names = names_file.read()
classes = raw_names.rstrip('\n').split('\n')

# Darknet model, configured to prefer the CUDA backend/target
modelConfiguration = "fast1.cfg"
modelWeights = "fast1.weights"
print("[INFO] loading model...")
net = cv.dnn.readNetFromDarknet(modelConfiguration, modelWeights)
net.setPreferableBackend(cv.dnn.DNN_BACKEND_CUDA)
net.setPreferableTarget(cv.dnn.DNN_TARGET_CUDA)

# Tracking state: the centroid tracker plus one TrackableObject per object ID
trackableObjects = {}
ct = CentroidTracker(maxDisappeared=40, maxDistance=50)

# Persistent totals: restored at start-up from the files the counter writes,
# falling back to zero when a file is missing, unreadable or not an integer.
count_file_in = "/home/ryudasan/dataku/vehicle/hitungorang.txt"
count_file_out = "/home/ryudasan/dataku/vehicle/hitungorangout.txt"


def _load_count(path):
    """Return the integer stored in *path*, or 0 if it cannot be read/parsed."""
    try:
        with open(path, 'r') as fh:
            return int(fh.read())
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError: empty or non-numeric
        # content (this also covers decoding errors).
        return 0


counts_down = _load_count(count_file_in)
counts_up = _load_count(count_file_out)

# Names of the classes that contribute to the tallies. They must match entries
# of the loaded .names file exactly; objects of any other class are skipped
# by the counter.
counted_class_names = ['motor', 'mobil', 'bus', 'truck']

# Per-class tally, starting at zero for every counted class
counts = dict.fromkeys(counted_class_names, 0)

# Flask application that serves the MJPEG stream
app = Flask(__name__)

# Most recent annotated frame, shared between the capture thread and the
# HTTP stream generator; access is guarded by `lock`.
lock = threading.Lock()
outputFrame = None

def getOutputsNames(net):
    """Return the names of the network's unconnected output layers.

    ``net.getUnconnectedOutLayers()`` yields 1-based layer indices. Depending
    on the OpenCV version these come back either as a flat array or as an
    Nx1 array; the result is flattened so both shapes are handled.

    Args:
        net: object exposing ``getLayerNames()`` and
            ``getUnconnectedOutLayers()`` (e.g. a ``cv.dnn`` network).

    Returns:
        List of output layer names, in the order reported by the network.
    """
    layersNames = net.getLayerNames()
    out_layer_ids = np.asarray(net.getUnconnectedOutLayers()).flatten()
    # Indices are 1-based; cast to a plain int so sequence indexing always works.
    return [layersNames[int(i) - 1] for i in out_layer_ids]

def iou_bbox(bbox1, bbox2):
    """Compute the intersection-over-union of two boxes.

    Both boxes are indexable as ``(x_min, y_min, x_max, y_max)``.

    Returns:
        Intersection area divided by union area, or 0 when the union area is
        not positive.
    """
    # Width/height of the overlap rectangle, clamped at zero for disjoint boxes
    overlap_w = max(0, min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0]))
    overlap_h = max(0, min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]))
    inter_area = overlap_w * overlap_h

    area_first = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    area_second = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
    union = area_first + area_second - inter_area

    if union > 0:
        return inter_area / union
    return 0

def postprocess(frame, outs):
    """Convert raw network outputs into NMS-filtered detections.

    Each output row is read as ``[cx, cy, w, h, _, score_0, score_1, ...]``
    with the box given relative to the frame size; element 4 is not used.
    Rows whose best class score exceeds ``confThreshold`` are kept and then
    passed through non-maximum suppression.

    Args:
        frame: image whose ``shape`` supplies the height and width used to
            scale the boxes to pixels.
        outs: iterable of output arrays from the network's forward pass.

    Returns:
        List of ``(left, top, right, bottom, classId, confidence)`` tuples in
        pixel coordinates.
    """
    frameHeight, frameWidth = frame.shape[0], frame.shape[1]

    boxes, confidences, classIds = [], [], []
    for layer_output in outs:
        for row in layer_output:
            class_scores = row[5:]
            best_class = np.argmax(class_scores)
            best_score = class_scores[best_class]
            if not best_score > confThreshold:
                continue

            # Scale the relative centre/size to pixels, then derive the corner
            box_w = int(row[2] * frameWidth)
            box_h = int(row[3] * frameHeight)
            box_left = int(int(row[0] * frameWidth) - box_w / 2)
            box_top = int(int(row[1] * frameHeight) - box_h / 2)

            boxes.append([box_left, box_top, box_w, box_h])
            confidences.append(float(best_score))
            classIds.append(best_class)

    kept = cv.dnn.NMSBoxes(boxes, confidences, confThreshold, nmsThreshold)
    if len(kept) == 0:
        return []

    detections = []  # (left, top, right, bottom, classId, confidence)
    for idx in kept.flatten():
        box_left, box_top, box_w, box_h = boxes[idx]
        detections.append(
            (box_left, box_top, box_left + box_w, box_top + box_h,
             classIds[idx], confidences[idx])
        )
    return detections

def counting(frame, objects, objectID_class):
    """Update the crossing counts for the tracked objects and draw the overlay.

    The counting line is horizontal at ``line_y``. An object that has not been
    counted yet is counted once, while its centroid lies within 50 px of the
    line, when it is moving away from the line:

    * moving up (current y below the mean of its past y values) and above
      the line -> ``counts_up``
    * moving down and below the line -> ``counts_down``

    In both cases the per-class entry in ``counts`` is incremented as well.
    Objects without a known class, or whose class is not in ``counts``, are
    tracked and drawn but never counted.

    Args:
        frame: image array; annotated in place.
        objects: mapping of object ID -> centroid ``(x, y)``.
        objectID_class: mapping of object ID -> index into ``classes``.

    Side effects:
        Mutates ``trackableObjects``, ``counts``, ``counts_up`` and
        ``counts_down``; rewrites both count files whenever a total changed.
    """
    global counts_up, counts_down

    frameWidth = frame.shape[1]

    # Horizontal counting line and the band around it in which counting occurs
    line_y = 200
    allowed_range_y_min = line_y - 50
    allowed_range_y_max = line_y + 50

    totals_changed = False

    # NOTE(review): entries are never removed from trackableObjects here, so
    # the dict grows with every new object ID - confirm whether pruning IDs
    # dropped by the tracker is wanted.
    for (objectID, centroid) in objects.items():
        # Plain ints: drawing calls can reject NumPy integer scalars.
        cx, cy = int(centroid[0]), int(centroid[1])
        to = trackableObjects.get(objectID, None)

        if to is None:
            to = TrackableObject(objectID, centroid)
        else:
            # Vertical motion relative to the object's history:
            # negative -> moving up, positive -> moving down.
            direction = cy - np.mean([c[1] for c in to.centroids])
            to.centroids.append(centroid)

            if not to.counted and allowed_range_y_min <= cy <= allowed_range_y_max:
                cls_id = objectID_class.get(objectID, None)
                cls_name = classes[cls_id] if cls_id is not None else None

                if cls_name in counts:
                    if direction < 0 and cy < line_y:
                        counts[cls_name] += 1
                        counts_up += 1   # overall up count
                        to.counted = True
                        totals_changed = True
                    elif direction > 0 and cy > line_y:
                        counts[cls_name] += 1
                        counts_down += 1  # overall down count
                        to.counted = True
                        totals_changed = True

        trackableObjects[objectID] = to

        # Draw centroid
        cv.circle(frame, (cx, cy), 4, (0, 255, 0), -1)

    # Draw line
    cv.line(frame, (0, line_y), (frameWidth, line_y), (0, 255, 255), 2)

    # Draw per-class counts, one row per class
    y0 = 30
    dy = 30
    for i, (cls_name, count) in enumerate(counts.items()):
        text = f"{cls_name}: {count}"
        cv.putText(frame, text, (10, y0 + i*dy), cv.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

    # Overall totals go on the two rows below the per-class rows
    cv.putText(frame, f"Total Entered: {counts_down}", (10, y0 + len(counts)*dy), cv.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
    cv.putText(frame, f"Total Exited: {counts_up}", (10, y0 + (len(counts)+1)*dy), cv.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

    # Persist the totals only when they changed, instead of rewriting both
    # files on every frame.
    if totals_changed:
        try:
            with open(count_file_in, 'w') as f:
                f.write(str(counts_down))
            with open(count_file_out, 'w') as f:
                f.write(str(counts_up))
        except OSError as e:
            print("[WARN] Could not write counts to file:", e)

def _match_detection_class(centroid, detections, max_dist=50, half_size=15):
    """Return the class id of the detection best matching *centroid*, or None.

    Only detections whose box centre lies within ``max_dist`` pixels of the
    centroid are considered; among those, the one with the highest IoU against
    a ``2 * half_size`` square centred on the centroid wins.
    """
    probe_box = (centroid[0] - half_size, centroid[1] - half_size,
                 centroid[0] + half_size, centroid[1] + half_size)
    own_position = np.array([centroid[0], centroid[1]])

    best_iou = 0
    best_class = None
    for det in detections:
        det_bbox = det[0:4]
        det_cx = int((det_bbox[0] + det_bbox[2]) / 2)
        det_cy = int((det_bbox[1] + det_bbox[3]) / 2)
        dist = np.linalg.norm(own_position - np.array([det_cx, det_cy]))
        if dist >= max_dist:
            continue
        iou = iou_bbox(probe_box, det_bbox)
        if iou > best_iou:
            best_iou = iou
            best_class = det[4]
    return best_class


def video_capture_loop():
    """Grab frames, detect, track and count, and publish the annotated frame.

    Runs until a frame cannot be read. For every frame the network is run,
    detections are passed to the centroid tracker, each tracked object is
    associated with a detection class (keeping its previous class when no
    detection matches), ``counting`` is invoked, and a copy of the annotated
    frame is stored in ``outputFrame`` under ``lock``.
    """
    global outputFrame

    cap = cv.VideoCapture("http://127.0.0.1:7071")  # Adjust camera index or video file if needed
    print("[INFO] Starting video capture stream...")

    # The output layer names do not depend on the frame: look them up once.
    output_names = getOutputsNames(net)

    objectID_class = {}  # tracker object ID -> detected class id

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("[INFO] Failed to grab frame")
                break

            blob = cv.dnn.blobFromImage(frame, 1/255, (inpWidth, inpHeight), [0,0,0], 1, crop=False)
            net.setInput(blob)
            outs = net.forward(output_names)

            # List of (left, top, right, bottom, classId, conf)
            detections = postprocess(frame, outs)
            rects = [(det[0], det[1], det[2], det[3]) for det in detections]

            objects = ct.update(rects)  # object ID -> centroid

            # Associate tracker IDs with detection classes
            new_objectID_class = {}
            for (objectID, centroid) in objects.items():
                matched_class = _match_detection_class(centroid, detections)
                if matched_class is not None:
                    new_objectID_class[objectID] = matched_class
                elif objectID in objectID_class:
                    # fallback: keep previous class if any
                    new_objectID_class[objectID] = objectID_class[objectID]
            objectID_class = new_objectID_class

            # Run counting based on this association
            counting(frame, objects, objectID_class)

            with lock:
                outputFrame = frame.copy()

            time.sleep(0.02)  # sleep to reduce CPU load
    finally:
        # Release the capture even if processing a frame raised.
        cap.release()

def generate_frames():
    """Yield the latest annotated frame as parts of a multipart MJPEG stream.

    The shared frame reference is read under ``lock``; JPEG encoding happens
    outside the lock so the capture thread is not blocked while encoding.
    This is safe as long as the producer replaces ``outputFrame`` with a new
    array for every frame rather than modifying the published one in place.
    The encoded bytes are cached per frame object, so an unchanged frame is
    re-sent without being re-encoded.

    Yields:
        ``bytes``: one ``--frame`` part containing a JPEG image.
    """
    idle_delay = 0.02  # seconds; avoids a hot spin when nothing is available
    last_frame = None
    payload = None

    while True:
        with lock:
            current = outputFrame

        if current is None:
            time.sleep(idle_delay)
            continue

        if current is not last_frame:
            ret, jpeg = cv.imencode('.jpg', current)
            if not ret:
                time.sleep(idle_delay)
                continue
            payload = jpeg.tobytes()
            last_frame = current

        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + payload + b'\r\n')

        # Pace the stream instead of pushing parts as fast as the CPU allows.
        time.sleep(idle_delay)

@app.route('/')
def index():
    """Serve a minimal HTML page that embeds the MJPEG stream."""
    page_template = """
    <html>
    <head><title>Vehicle Counter MJPEG Stream</title></head>
    <body>
    <h1>Vehicle Counter MJPEG Stream</h1>
    <img src="{{ url_for('video_feed') }}" width="900" />
    </body>
    </html>
    """
    rendered = render_template_string(page_template)
    return rendered

@app.route('/video_feed')
def video_feed():
    """Return the frame generator wrapped in a multipart streaming response."""
    frame_stream = generate_frames()
    return Response(
        frame_stream,
        mimetype='multipart/x-mixed-replace; boundary=frame',
    )

if __name__ == "__main__":
    # Capture and inference run in a daemon thread so they stop with the
    # process; the Flask server occupies the main thread.
    capture_thread = threading.Thread(target=video_capture_loop, daemon=True)
    capture_thread.start()
    app.run(host='0.0.0.0', port=8080, debug=False)