import React, { useEffect } from "react";
import { useRef } from "react";

/**
 * Props accepted by {@link FocusTrap}.
 */
export interface FocusTrapProps {
    /** When true, Tab / Shift+Tab wrap around inside the trap and Escape is handled. */
    active: boolean;
    /** Called when Escape is pressed inside the trap while it is active. */
    onEsc?: () => void;
    /**
     * Called once on mount with the element that had focus before the trap
     * moved it (falls back to `document.body`).
     */
    setPreviousActiveElement?: (element: HTMLElement) => void;
    /** Class name forwarded to the wrapping `<div>`. */
    className?: string;
}

// Add any focusable HTML element you want to include to this string.
// Elements that Tab would skip anyway (disabled controls, tabindex="-1")
// are filtered out in getFocusableElements below.
const FOCUSABLE_SELECTOR =
    'button, [href], input:not([type="hidden"]), select, textarea, [tabindex]';

/**
 * Returns the elements inside `container` that take part in sequential
 * (Tab) navigation, in document order.
 *
 * Disabled controls and `tabindex="-1"` elements are excluded: Tab never
 * lands on them, so using one as the first/last boundary would let focus
 * slip out of the trap.
 */
const getFocusableElements = (container: HTMLElement): HTMLElement[] =>
    Array.from(container.querySelectorAll<HTMLElement>(FOCUSABLE_SELECTOR)).filter(
        (el) => el.getAttribute("tabindex") !== "-1" && !el.matches(":disabled")
    );

/**
 * Wraps its children in a `<div>` and keeps keyboard focus inside it.
 *
 * While `active` is true, a `keydown` listener on the wrapper does two things:
 * - Tab on the last focusable element moves focus to the first one, and
 *   Shift+Tab on the first moves it to the last.
 * - Escape calls `onEsc`.
 *
 * Only key presses whose target is inside the wrapper are seen.
 *
 * On mount (regardless of `active`) it does two more things:
 * - It reports the currently focused element (or `document.body`) via
 *   `setPreviousActiveElement`.
 * - It moves focus into the trap. Focus goes to the first focusable element.
 *   If a `button.button-close-block` exists, focus goes to the second
 *   focusable element instead, falling back to the close button.
 *
 * This component does not restore focus itself.
 */
const FocusTrap = ({
    children,
    active,
    className,
    onEsc,
    setPreviousActiveElement
}: React.PropsWithChildren<FocusTrapProps>) => {
    const trapRef = useRef<HTMLDivElement>(null);

    useEffect(() => {
        // Capture the node here. On unmount React clears trapRef.current
        // before this effect's cleanup runs, so reading the ref there would
        // skip removeEventListener.
        const trap = trapRef.current;
        if (!active || !trap) {
            return undefined;
        }

        const handleKeyDown = (event: KeyboardEvent) => {
            if (event.key === "Tab") {
                // Re-query on every Tab so children added, removed or
                // (un)disabled after mount are taken into account.
                const focusable = getFocusableElements(trap);
                const first = focusable[0];
                const last = focusable[focusable.length - 1];

                if (!first || !last) {
                    // Nothing to cycle through: keep focus where it is instead
                    // of letting the browser move it out of the trap.
                    event.preventDefault();
                    return;
                }

                if (event.shiftKey && document.activeElement === first) {
                    event.preventDefault();
                    last.focus();
                } else if (!event.shiftKey && document.activeElement === last) {
                    event.preventDefault();
                    first.focus();
                }
            }

            if (onEsc && event.key === "Escape") {
                onEsc();
            }
        };

        trap.addEventListener("keydown", handleKeyDown);
        return () => {
            trap.removeEventListener("keydown", handleKeyDown);
        };
    }, [active, onEsc]);

    // Runs once on mount only: records the previously focused element and
    // moves initial focus into the trap.
    useEffect(() => {
        setPreviousActiveElement?.((document.activeElement as HTMLElement) ?? document.body);

        const trap = trapRef.current;
        if (!trap) {
            return;
        }

        const focusable = getFocusableElements(trap);
        const closeButton = trap.querySelector<HTMLElement>("button.button-close-block");

        if (!closeButton) {
            focusable[0]?.focus();
        } else {
            // NOTE(review): presumably the close button is the first focusable
            // element, so the second one is the first "real" content control.
            // This relies on DOM order — confirm against the modal markup.
            (focusable[1] ?? closeButton).focus();
        }
    }, []);

    return (
        <div className={className} ref={trapRef}>
            {children}
        </div>
    );
};

export default FocusTrap;
