import { OrbitControls } from "@react-three/drei";
import { Canvas, useFrame } from "@react-three/fiber";
import { useEffect, useMemo, useRef, memo, useState } from "react";
import * as THREE from "three";
import fragmentShader from "./fragmentShader";
import vertexShader from "./vertexShader";

/**
 * Builds the uniform object handed to the particle shaderMaterial.
 *
 * The object is created once per component instance (useMemo with an empty
 * dependency list), so the same references survive re-renders and the values
 * can be mutated in place each frame.
 *
 * @returns {{ uTime: { value: number }, uRadius: { value: number } }}
 */
const useUniforms = () =>
  useMemo(() => {
    const uTime = { value: 0.0 };
    const uRadius = { value: 3 }; // Radius value
    return { uTime, uRadius };
  }, []);

const RingParticles = memo(function RingParticles(props) {
  const { count, animation } = props; // Receive the animation prop
  const radius = 3;
  const points = useRef();
  const uniforms = useUniforms();
  const [animationProgress, setAnimationProgress] = useState(0); // Track progress

  const particlesPosition = useMemo(() => {
    const positions = new Float32Array(count * 3);
    for (let i = 0; i < count; i++) {
      const distance = Math.sqrt(Math.random()) * radius;
      const theta = THREE.MathUtils.randFloatSpread(360);
      const phi = THREE.MathUtils.randFloatSpread(360);
      const x = distance * Math.sin(theta) / Math.PI * Math.cos(phi); // added Math.PI
      const y = distance * Math.sin(theta) * Math.sin(phi);
      const z = distance * Math.cos(theta);
      positions.set([x, y, z], i * 3);
    }
    return positions;
  }, [count, radius]);

  const geometry = useMemo(() => {
    const geometry = new THREE.BufferGeometry();
    geometry.setAttribute(
      'position',
      new THREE.BufferAttribute(particlesPosition, 3)
    );
    return geometry;
  }, [particlesPosition]);

  const interpolatedPositions = useMemo(() => {
    const positions = new Float32Array(count * 3);
    for (let i = 0; i < count; i++) {
      const startPosition = particlesPosition.subarray(i * 3, i * 3 + 3);
      const targetPosition = [0, 0, 0]; // Always interpolate towards this

      const interpolatedPosition = startPosition.map((start, index) =>
        THREE.MathUtils.lerp(targetPosition[index], start, animationProgress)
      );

      positions.set(interpolatedPosition, i * 3);
    }
    return positions;
  }, [animationProgress, particlesPosition, count]);

  useFrame((state, delta) => {
    const { clock } = state;
    if (points.current) {
      points.current.material.uniforms.uTime.value = clock.elapsedTime;

      setAnimationProgress(prev => {
        const nextProgress = animation ? Math.min(prev + delta, 1) : Math.max(prev - delta, 0);
        return nextProgress;
      });

      points.current.geometry.setAttribute(
        'position',
        new THREE.BufferAttribute(interpolatedPositions, 3)
      );
      points.current.geometry.attributes.position.needsUpdate = true;
    }
  });

  return (
    <points ref={points} renderOrder={20} position={[0, -1, 0]} scale={[1, 4, 1]}>
      <primitive object={geometry} />
      <shaderMaterial
        blending={THREE.AdditiveBlending}
        depthWrite={false}
        fragmentShader={fragmentShader}
        vertexShader={vertexShader}
        uniforms={uniforms}
      />
    </points>
  );
});

export default RingParticles;
