// src/components/CaseViewer.tsx

import React, { useState, useEffect, useCallback, useMemo, useRef } from 'react';
import { FaSpinner } from 'react-icons/fa';
import { useAuth } from '../context/AuthContext';
import PredictionsDisplay from './PredictionsDisplay';
import { Distance } from './ImageViewer';
import { useWorker } from '../context/WorkerContext';

import { NavbarConfig } from '../App';
import {
  generateReport,
  type StructuredReport,
  type ReportInput,
  organizeFindings,
  createSpatialGroups,
} from '../utils/reportGenerator';
import { PredictionsProvider } from '../context/PredictionsContext';
import { usePredictions } from '../context/PredictionsContext';
import ViewerControls from './ViewerControls';
import ImageViewer from './ImageViewer';
import { ViewerMode, ViewerMetrics, StudyData, FeedbackRecord } from './types';
import { useSeriesLoader } from '../hooks/useSeriesLoader';
import { extractFilenameFromUrl, getBaseFilename } from '../utils/fileHelpers';
import { API_BASE_URL } from '../config/api';
import { useOpenAIConfig } from '../context/OpenAIConfigContext';
import axios from 'axios';

/**
 * Axis-aligned 2-D box described by its x and y extents.
 * NOTE(review): the coordinate space is defined by the model backend — presumably
 * pixels of the slice image; confirm before relying on it.
 */
interface BoundingBox {
  /** Left edge. */
  xmin: number;
  /** Right edge. */
  xmax: number;
  /** Top edge. */
  ymin: number;
  /** Bottom edge. */
  ymax: number;
}

/**
 * One entry of a model's per-condition `slices` list: the box drawn on a
 * single slice together with that slice's importance score.
 */
interface SliceBoundingBox {
  /** Box on this slice. */
  bbox: BoundingBox;
  /** Slice index as reported by the model (the viewer flips it before display). */
  slice_idx: number;
  /** Score used to pick the most relevant slice; higher wins. */
  slice_importance: number;
}


/**
 * Props accepted by the CaseViewer component.
 */
export interface CaseViewerProps {
  /**
   * Optional hook invoked after feedback is edited in the predictions panel,
   * so the parent list (the worklist) can update its own copy.
   */
  onFeedbackUpdate?: (studyId: string, feedback: FeedbackRecord) => void;
  /** Setter for the application navbar configuration. */
  setNavbarConfig: (config: NavbarConfig) => void;
  /** Initial study payload; CaseViewer re-fetches fresh study info on mount. */
  studyData: StudyData & {
    /** Which model's output this study entry refers to. */
    modelType: 'base' | 'experimental';
    /** Free-text report, when one exists. */
    report?: string;
    /**
     * Per-key feedback tuples.
     * NOTE(review): key and tuple-element meaning are not visible here — confirm
     * against the backend schema.
     */
    feedback?: Record<string, [string, number, string]>;
  };
}

/**
 * Fetches predictions, series URLs, preliminary report and per-series metadata
 * for one study from the backend `get_study_info` endpoint.
 *
 * @param studyId - Study to look up (sent as `study_id`).
 * @param token - Bearer token placed in the `Authorization` header.
 * @param userId - Sent as `user_id` in the request body.
 * @returns The normalized study info when the response body reports
 *   `status === 200`; otherwise `null`. Request errors are logged and also
 *   yield `null`, so callers only need a null check.
 */
const fetchStudyInfo = async (studyId: string, token: string, userId: string) => {
  try {
    const response = await axios.post(
      `${API_BASE_URL}/get_study_info`,
      {
        user_id: userId,
        study_id: studyId
      },
      {
        headers: {
          "Content-Type": "application/json",
          // NOTE(review): placeholder key hard-coded in source; it should be
          // supplied from configuration rather than committed here.
          "Ocp-Apim-Subscription-Key": "replace_with_subscription_key",
          "Authorization": `Bearer ${token}`
        }
      }
    );

    // The backend reports its own status inside the body; anything other than
    // 200 there is treated as "no data".
    if (response.data.status !== 200) {
      return null;
    }

    const data = response.data;
    const rawMetadata = data.metadata || {};

    return {
      base_model_prediction: data.base_model_prediction,
      experimental_model_prediction: data.experimental_model_prediction,
      seriesUrls: data.series_urls,
      selectedSeriesId: data.selected_series_id,
      prelim: data.prelim || null,
      prelim_feedback: data.prelim_feedback || null,
      report: data.report,
      // Every per-series entry must be a non-null object. `typeof null` is
      // 'object', so null is excluded explicitly; null and primitives become {}.
      metadata: Object.entries(rawMetadata).reduce((acc, [seriesId, meta]) => ({
        ...acc,
        [seriesId]: meta !== null && typeof meta === 'object' ? meta : {}
      }), {})
    };
  } catch (error) {
    console.error('Error fetching study info:', error);
    return null;
  }
};
/**
 * Thin progress bar overlaid on the image canvas while slices are loading.
 *
 * Declared at module scope rather than inside CaseViewer. A component defined
 * inside another component's body is a new component type on every render, so
 * React would unmount and remount its DOM each time, and the width transition
 * would never animate.
 *
 * @param progress - Percentage; nothing is rendered when it is <= 0 or >= 100.
 */
const LoadingProgress: React.FC<{ progress: number }> = ({ progress }) => {
  if (progress <= 0 || progress >= 100) return null;

  return (
    <div className="absolute top-0 left-0 right-0 z-10">
      <div className="h-1 bg-gray-700/50">
        <div
          className="h-full bg-primary transition-all duration-300 ease-out"
          style={{ width: `${progress}%` }}
        />
      </div>
      <div className="absolute top-2 left-1/2 -translate-x-1/2 bg-gray-800/90 text-xs text-gray-300 px-2 py-1 rounded-full">
        Loading slices: {Math.round(progress)}%
      </div>
    </div>
  );
};

/**
 * State updater that turns off every bounding-box highlight: condition, model
 * and spatial-group selection. Used whenever the prediction set is replaced or
 * discarded.
 */
const withClearedHighlights = (prev: ViewerMetrics): ViewerMetrics => ({
  ...prev,
  showBoundingBoxes: false,
  selectedConditions: new Set(),
  selectedModel: null,
  selectedGroup: null
});

/**
 * Full-screen study viewer. It has three panes:
 * - left: series selection and windowing controls;
 * - centre: the slice canvas;
 * - right: model predictions, report and feedback.
 *
 * On mount, and whenever the study id or the signed-in user/token changes, it
 * re-fetches study info from the backend. It then pushes the predictions into
 * the surrounding predictions context and clears any highlighting.
 */
const CaseViewer: React.FC<CaseViewerProps> = ({
  studyData,
  setNavbarConfig,
  onFeedbackUpdate
}) => {
  const { userId, token } = useAuth();
  const { terminateAllWorkers } = useWorker();
  const { setPredictions, setBaseModelPredictions, setExperimentalModelPredictions } = usePredictions();

  // ---- State --------------------------------------------------------------
  // Declared before the effects below, which call these setters.
  const [study, setStudy] = useState<StudyData>({
    studyId: studyData.studyId,
    series: [],
    isLoading: false,
    base_model_prediction: studyData.base_model_prediction,
    experimental_model_prediction: studyData.experimental_model_prediction,
    seriesUrls: studyData.seriesUrls,
    modelSelectedSeriesId: studyData.modelSelectedSeriesId,
    prelim: studyData.prelim,
    report: studyData.report,
    metadata: studyData.metadata
  });
  // Tracks the study-info request; not currently read by the UI.
  const [isLoadingPredictions, setIsLoadingPredictions] = useState(false);
  const [spatialGroups, setSpatialGroups] = useState(() =>
    createSpatialGroups(studyData.experimental_model_prediction)
  );
  const [viewerMetrics, setViewerMetrics] = useState<ViewerMetrics>({
    zoom: 1,
    pan: { x: 0, y: 0 },
    showHUValue: false,
    showCrosshair: false,
    showMeasurements: false,
    showBoundingBoxes: false,
    selectedConditions: new Set(),
    selectedModel: null,
    currentSlice: 0,
    selectedGroup: null,
  });
  const [report, setReport] = useState<StructuredReport | null>(studyData.prelim || null);
  const [isGeneratingReport, setIsGeneratingReport] = useState(false);

  // Re-fetch study info for the current study/user and reset per-study state
  // when leaving it.
  useEffect(() => {
    // Set by the cleanup so that a response for a previous study, or one that
    // arrives after unmount, is ignored instead of overwriting newer state.
    // The HTTP request itself is not aborted.
    let cancelled = false;

    const refreshStudy = async () => {
      if (!token || !userId) return;

      setIsLoadingPredictions(true);
      const studyInfo = await fetchStudyInfo(studyData.studyId, token, userId);
      if (cancelled) return;
      setIsLoadingPredictions(false);

      // fetchStudyInfo logs failures and returns null; keep the current state.
      if (!studyInfo) return;

      setStudy(prev => ({
        ...prev,
        base_model_prediction: studyInfo.base_model_prediction,
        experimental_model_prediction: studyInfo.experimental_model_prediction,
        seriesUrls: studyInfo.seriesUrls,
        modelSelectedSeriesId: studyInfo.selectedSeriesId,
        metadata: studyInfo.metadata,
        prelim: studyInfo.prelim,
        report: studyInfo.report
      }));

      setBaseModelPredictions(studyInfo.base_model_prediction || {});
      setExperimentalModelPredictions(studyInfo.experimental_model_prediction || {});
      setPredictions(studyInfo.experimental_model_prediction || studyInfo.base_model_prediction || {});

      setSpatialGroups(createSpatialGroups(studyInfo.experimental_model_prediction));

      // Highlights referred to the previous predictions; drop them.
      setViewerMetrics(withClearedHighlights);

      if (studyInfo.prelim) {
        setReport(studyInfo.prelim);
      }
    };

    void refreshStudy();

    return () => {
      cancelled = true;
      setIsLoadingPredictions(false);
      setPredictions({});
      setBaseModelPredictions({});
      setExperimentalModelPredictions({});
      setSpatialGroups([]);
      setViewerMetrics(withClearedHighlights);
    };
  }, [token, userId, studyData.studyId, setBaseModelPredictions, setExperimentalModelPredictions, setPredictions]);

  // Mirror the study's predictions into the predictions context and rebuild
  // the spatial groups whenever either prediction set changes.
  useEffect(() => {
    if (study.experimental_model_prediction || study.base_model_prediction) {
      setPredictions(study.experimental_model_prediction || study.base_model_prediction || {});
      setBaseModelPredictions(study.base_model_prediction || {});
      setExperimentalModelPredictions(study.experimental_model_prediction || {});

      setSpatialGroups(createSpatialGroups(study.experimental_model_prediction));
    }
  }, [
    study.base_model_prediction,
    study.experimental_model_prediction,
    setPredictions,
    setBaseModelPredictions,
    setExperimentalModelPredictions
  ]);

  // NOTE(review): this worker is never terminated in this component; confirm
  // that useSeriesLoader or terminateAllWorkers owns its lifetime.
  const npzWorker = useMemo(
    () => new Worker(new URL('../workers/npzWorker.ts', import.meta.url)),
    []
  );

  const {
    selectedSeriesId,
    handleSeriesSelect,
    loadingProgress,
  } = useSeriesLoader({
    study,
    setStudy,
    userId: userId || '',
    npzWorker
  });

  const [currentSlice, setCurrentSlice] = useState<number>(0);
  const [totalSlices, setTotalSlices] = useState<number>(0);
  const [isSeriesLoading, setIsSeriesLoading] = useState<boolean>(false);
  const [activeGroupId, setActiveGroupId] = useState<string | null>(null);
  const [measurements, setMeasurements] = useState<Distance[]>([]);

  // Windowing (level/width) state.
  const [windowCenter, setWindowCenter] = useState<number>(40);
  const [windowWidth, setWindowWidth] = useState<number>(400);
  const [viewerMode, setViewerMode] = useState<ViewerMode>(ViewerMode.Default);

  // Series currently shown, matched by base filename.
  const currentSeries = useMemo(() => {
    return study.series.find(
      s => getBaseFilename(s.filename) === selectedSeriesId
    );
  }, [study.series, selectedSeriesId]);

  /** Stores edited feedback locally and forwards it to the parent (worklist). */
  const handleFeedbackUpdate = (studyId: string, updatedFeedback: FeedbackRecord) => {
    setStudy(prev => ({
      ...prev,
      feedback: updatedFeedback
    }));
    onFeedbackUpdate?.(studyId, updatedFeedback);
  };

  const handleClearMeasurements = useCallback(() => {
    setMeasurements([]);
  }, []);

  /**
   * Moves to another slice, but only if that slice's pixel data has already
   * been loaded. Accepts either an absolute index or an updater function.
   */
  const handleSliceChange = useCallback((
    newSliceOrUpdater: number | ((prev: number) => number)
  ) => {
    const targetSlice = typeof newSliceOrUpdater === 'function'
      ? newSliceOrUpdater(currentSlice)
      : newSliceOrUpdater;

    if (currentSeries?.data?.[targetSlice] !== undefined) {
      setCurrentSlice(targetSlice);
    } else {
      console.log('Slice not yet loaded:', targetSlice);
    }
  }, [currentSeries, currentSlice]);

  /** Regenerates the structured report from the experimental predictions. */
  const handleRegenerateReport = useCallback(async () => {
    if (!userId || !token) return;

    setIsGeneratingReport(true);
    try {
      const predictions = study.experimental_model_prediction;
      const generatedReport = await generateReport(predictions, { token, userId });

      if (generatedReport) {
        setReport(generatedReport);
      }
    } catch (error) {
      console.error('Error regenerating report:', error);
    } finally {
      setIsGeneratingReport(false);
    }
  }, [userId, token, study.experimental_model_prediction]);

  // Keep the displayed report in sync with the preliminary report in state...
  useEffect(() => {
    if (study.prelim) {
      setReport(study.prelim);
    }
  }, [study.prelim]);

  // ...and with the one supplied via props.
  useEffect(() => {
    if (studyData.prelim) {
      setReport(studyData.prelim);
    }
  }, [studyData.prelim]);

  // Push prop-supplied predictions into context. Use the same
  // experimental -> base -> {} fallback as every other call site, so a missing
  // experimental set never overwrites the context with undefined.
  useEffect(() => {
    setPredictions(
      studyData.experimental_model_prediction || studyData.base_model_prediction || {}
    );
  }, [studyData.experimental_model_prediction, studyData.base_model_prediction, setPredictions]);

  const handleWindowPreset = useCallback(
    (preset: { center: number; width: number }) => {
      setWindowCenter(preset.center);
      setWindowWidth(preset.width);
    },
    []
  );

  /**
   * Highlights one condition's bounding boxes for the given model.
   *
   * If the model's series is not the one shown, it switches to that series
   * first. It then jumps to the slice with the highest `slice_importance`.
   * Passing a null condition or model clears the highlight.
   */
  const handleConditionSelect = useCallback(
    async (
      condition: string | null,
      modelType: 'base' | 'experimental' | null
    ) => {
      if (!condition || !modelType) {
        setViewerMetrics((prev) => ({
          ...prev,
          showBoundingBoxes: false,
          selectedConditions: new Set(),
          selectedModel: null,
        }));
        return;
      }

      const predictions =
        modelType === 'base'
          ? study.base_model_prediction
          : study.experimental_model_prediction;

      // The chosen model may have no predictions at all. Guard the lookup,
      // since it sits outside the try below.
      const bboxData = predictions?.[`bbox_${condition}`];
      if (!bboxData || bboxData === '{}') return;

      try {
        const parsedBBox =
          typeof bboxData === 'string'
            ? JSON.parse(bboxData)
            : (bboxData as {
                slices: SliceBoundingBox[];
              });

        if (!parsedBBox.slices?.length) return;

        const modelSeriesId = study.modelSelectedSeriesId;
        if (!modelSeriesId) return;

        if (modelSeriesId !== selectedSeriesId) {
          await handleSeriesSelect(modelSeriesId);
        }

        const bestSlice = parsedBBox.slices.reduce(
          (best: SliceBoundingBox, candidate: SliceBoundingBox) =>
            candidate.slice_importance > best.slice_importance ? candidate : best,
          parsedBBox.slices[0]
        );

        // NOTE(review): `study.series` is this render's snapshot; after awaiting
        // a series switch it may not yet include the newly loaded series.
        const modelSeries = study.series.find(
          (s) => getBaseFilename(s.filename) === getBaseFilename(modelSeriesId)
        );
        const seriesSliceCount = modelSeries?.dims[0] || 0;
        if (!seriesSliceCount) return;

        // Model slice indices run in the opposite direction to the viewer's.
        setCurrentSlice(seriesSliceCount - 1 - bestSlice.slice_idx);

        setViewerMetrics((prev) => ({
          ...prev,
          showBoundingBoxes: true,
          selectedConditions: new Set([condition]),
          selectedModel: modelType,
        }));
      } catch (error) {
        console.error('Error parsing bounding box data:', error);
      }
    },
    [
      study.experimental_model_prediction,
      study.base_model_prediction,
      study.modelSelectedSeriesId,
      study.series,
      selectedSeriesId,
      handleSeriesSelect,
    ]
  );

  // Track the slice count of the shown series; rewind to the first slice when it changes.
  useEffect(() => {
    const shownSeries = study.series.find(
      (s) => getBaseFilename(s.filename) === selectedSeriesId
    );

    if (!shownSeries?.dims) {
      setIsSeriesLoading(true);
      return;
    }

    const newTotalSlices = shownSeries.dims[0];
    if (newTotalSlices === totalSlices) return;

    setTotalSlices(newTotalSlices);
    setCurrentSlice(0);
  }, [study.series, selectedSeriesId, totalSlices]);

  const handleViewerMetricsUpdate = useCallback(
    (value: React.SetStateAction<ViewerMetrics>) => {
      setViewerMetrics(value);
    },
    []
  );

  // Clear condition highlighting and rewind when the selected series changes.
  // NOTE(review): this keys on `study.selectedSeriesId`, not the
  // `selectedSeriesId` returned by useSeriesLoader. Confirm that field is kept
  // up to date, or this effect only runs on mount.
  useEffect(() => {
    setViewerMetrics((prev) => ({
      ...prev,
      showBoundingBoxes: false,
      selectedConditions: new Set(),
      selectedModel: null,
    }));
    setCurrentSlice(0);
  }, [study.selectedSeriesId]);

  /**
   * Highlights one spatial group.
   *
   * Switches to the model's series if needed, then jumps to the slice midway
   * between the group's start and end slices (in flipped viewer order).
   * Passing null clears the group highlight.
   */
  const handleGroupSelect = useCallback(
    async (groupId: string | null) => {
      if (!groupId) {
        setActiveGroupId(null);
        setViewerMetrics((prev) => ({
          ...prev,
          showBoundingBoxes: false,
          selectedGroup: null,
        }));
        return;
      }

      const selectedGroup = spatialGroups.find((group) => group.id === groupId);
      if (!selectedGroup) return;

      const modelSeriesId = study.modelSelectedSeriesId;
      if (!modelSeriesId) return;

      if (modelSeriesId !== selectedSeriesId) {
        await handleSeriesSelect(modelSeriesId);
      }

      // NOTE(review): same stale-snapshot caveat as in handleConditionSelect.
      const modelSeries = study.series.find(
        (s) => getBaseFilename(s.filename) === getBaseFilename(modelSeriesId)
      );
      const seriesSliceCount = modelSeries?.dims[0] || 0;
      if (!seriesSliceCount) return;

      const midSlice = Math.round((selectedGroup.bbox.startSlice + selectedGroup.bbox.endSlice) / 2);
      setCurrentSlice(seriesSliceCount - 1 - midSlice);

      setActiveGroupId(groupId);
      setViewerMetrics((prev) => ({
        ...prev,
        showBoundingBoxes: true,
        selectedGroup: groupId,
      }));
    },
    [
      spatialGroups,
      study.modelSelectedSeriesId,
      selectedSeriesId,
      study.series,
      handleSeriesSelect,
    ]
  );

  // Stop all context-managed workers when the viewer unmounts.
  useEffect(() => {
    return () => {
      terminateAllWorkers();
    };
  }, [terminateAllWorkers]);

  useEffect(() => {
    if (study.experimental_model_prediction) {
      setSpatialGroups(createSpatialGroups(study.experimental_model_prediction));
    }
  }, [study.experimental_model_prediction]);

  // NOTE(review): ViewerControls receives `measurements={[]}` while ImageViewer
  // receives the live `measurements` state. Confirm this is intentional.
  return (
    <PredictionsProvider
      initialPredictions={study.experimental_model_prediction || {}}
      initialModel="experimental"
    >
      <div className="flex h-screen bg-gray-900 text-white overflow-hidden">
        <div className="flex-1 flex">
          {/* Left sidebar - always show controls */}
          <div className="w-64 flex-shrink-0 p-4 overflow-y-auto">
            <ViewerControls
              series={study.series}
              selectedSeriesId={selectedSeriesId}
              onSelect={handleSeriesSelect}
              studyId={study.studyId}
              allSeriesUrls={study.seriesUrls}
              windowCenter={windowCenter}
              windowWidth={windowWidth}
              onWindowingChange={(center: number, width: number) => {
                setWindowCenter(center);
                setWindowWidth(width);
              }}
              viewerMode={viewerMode}
              onViewerModeChange={setViewerMode}
              measurements={[]}
              onClearMeasurements={handleClearMeasurements}
              onWindowPresetChange={handleWindowPreset}
              viewerMetrics={viewerMetrics}
              setViewerMetrics={handleViewerMetricsUpdate}
              study={study}
              onLoadSeries={handleSeriesSelect}
              currentSlice={currentSlice}
              totalSlices={totalSlices}
              onSliceChange={handleSliceChange}
            />
          </div>

          {/* Main content area */}
          <div className="flex-1 flex">
            {/* Center canvas */}
            <div className="flex-1 relative flex items-center justify-center p-4">
              <div className="w-full h-full flex items-center justify-center relative">
                {loadingProgress > 0 && loadingProgress < 100 && (
                  <LoadingProgress progress={loadingProgress} />
                )}
                {(!currentSeries?.data || !currentSeries.data[currentSlice]) && loadingProgress < 100 ? (
                  <div className="w-full h-full flex items-center justify-center bg-gray-800/50 backdrop-blur-sm rounded-lg">
                    <div className="max-w-md w-full mx-4">
                      <div className="bg-gray-800 border border-gray-700 rounded-xl p-6 shadow-2xl">
                        <div className="mb-6 text-center">
                          <div className="inline-block p-3 bg-gray-700/50 rounded-full mb-4">
                            <FaSpinner className="animate-spin text-primary w-8 h-8" />
                          </div>
                          <h3 className="text-lg font-semibold text-white">
                            Initializing Series...
                          </h3>
                        </div>
                      </div>
                    </div>
                  </div>
                ) : (
                  <ImageViewer
                    study={study}
                    selectedSeriesId={selectedSeriesId}
                    currentSlice={currentSlice}
                    setCurrentSlice={handleSliceChange}
                    viewerMetrics={viewerMetrics}
                    setViewerMetrics={setViewerMetrics}
                    windowCenter={windowCenter}
                    setWindowCenter={setWindowCenter}
                    windowWidth={windowWidth}
                    setWindowWidth={setWindowWidth}
                    viewerMode={viewerMode}
                    setViewerMode={setViewerMode}
                    totalSlices={totalSlices}
                    setTotalSlices={setTotalSlices}
                    handleWindowPreset={handleWindowPreset}
                    baseModelPredictions={study.base_model_prediction}
                    experimentalModelPredictions={study.experimental_model_prediction}
                    spatialGroups={spatialGroups}
                    activeGroupId={activeGroupId}
                    measurements={measurements}
                    setMeasurements={setMeasurements}
                    onClearMeasurements={handleClearMeasurements}
                    availableSlices={
                      currentSeries?.data
                        ?.map((slice: Int16Array[] | undefined, index: number) =>
                          slice ? index : undefined
                        )
                        .filter((x): x is number => x !== undefined) ?? []
                    }
                    currentSeries={currentSeries}
                  />
                )}
              </div>
            </div>

            {/* Right sidebar */}
            <div className="w-96 flex-shrink-0 h-screen overflow-hidden">
              <div className="h-full p-2">
                <PredictionsDisplay
                  studyId={studyData.studyId}
                  feedback={studyData.feedback}
                  onConditionSelect={handleConditionSelect}
                  setViewerMetrics={handleViewerMetricsUpdate}
                  viewerMetrics={viewerMetrics}
                  report={report}
                  isGeneratingReport={isGeneratingReport}
                  onRegenerateReport={handleRegenerateReport}
                  baseModelPredictions={study.base_model_prediction || {}}
                  experimentalModelPredictions={study.experimental_model_prediction || {}}
                  spatialGroups={spatialGroups}
                  setCurrentSlice={setCurrentSlice}
                  totalSlices={totalSlices}
                  onGroupSelect={handleGroupSelect}
                  activeGroupId={activeGroupId}
                  prelim={study.prelim}
                  prelim_feedback={studyData.prelim_feedback || undefined}
                  studyReport={study.report}
                  onFeedbackUpdate={handleFeedbackUpdate}
                />
              </div>
            </div>
          </div>
        </div>
      </div>
    </PredictionsProvider>
  );
};

export default CaseViewer;
