import { useState, useEffect, useMemo } from 'react'
import moment from 'moment'
import _ from 'lodash'

import { getMMMRuns, getMMMData } from '../../api/attributionModel'
import { buildChannelTactics, roundNumber, calcMaxCurvature, findClosestPoint, calcPercentDiff } from './helpers'

/**
 * Build the response curve for a single feature.
 *
 * Keeps only rows of `rawCurves` whose `feature_name` matches `featureKey` and whose
 * `spend_input` lies inside the feature's training spend range. Rows with a duplicate
 * `spend_input` are dropped (first occurrence wins). The result is sorted by
 * `spend_input` in descending order, and each row gets
 * `predicted_roas = marginal_response / spend_input` (a `predicted_roas` already present
 * on the row takes precedence).
 *
 * @param {Array<object>} rawCurves - rows from `modelData.raw.response_curves`
 * @param {*} featureKey - feature identifier (the `value` of a built feature)
 * @param {{ spend_min?: number, spend_max?: number }} spendRange - training spend bounds;
 *   if either bound is missing, every row is filtered out
 * @returns {Array<object>} filtered, de-duplicated curve in descending spend order
 */
const buildResponseCurve = (rawCurves, featureKey, { spend_min, spend_max } = {}) => {
  const inRange = rawCurves.filter(row =>
    // loose equality kept from the original — NOTE(review): confirm feature_name / featureKey types match
    row.feature_name == featureKey &&
    row.spend_input >= spend_min &&
    row.spend_input <= spend_max
  )

  return _.uniqBy(inRange, 'spend_input')
    .sort((a, b) => b.spend_input - a.spend_input) // descending spend
    .map(row => ({
      predicted_roas: row.marginal_response / row.spend_input,
      ...row,
    }))
}

/**
 * Summarise one feature's response curve.
 *
 * Computes the ROAS range, the spend range, a "current" spend (the curve point closest
 * to the historical mean spend), and a recommended spend/revenue/ROAS.
 *
 * @param {Array<object>} curve - non-empty curve from `buildResponseCurve` (descending spend)
 * @param {object} featureTrainingStats - the feature's training-stats row, or `{}` if unknown
 * @returns {object} stats stored under `featureStats[featureKey]`
 */
const computeFeatureStats = (curve, featureTrainingStats) => {
  // min/max ROAS; fall back to 0.1 when the rounded value is 0 or missing
  const roasValues = curve.map(({ predicted_roas }) => predicted_roas)
  const minRoas = roundNumber(_.min(roasValues), 1) || 0.1
  const maxRoas = roundNumber(_.max(roasValues), 1) || 0.1

  const historicalSpendMin = featureTrainingStats.spend_min || 0
  const historicalSpendMax = featureTrainingStats.spend_max || 0
  const meanSpend = featureTrainingStats.spend_mean || 0

  // "current" spend = point on the response curve closest to the historical mean spend
  const currentSpend = findClosestPoint(curve, 'spend_input', meanSpend)
  const foundIndex = curve.findIndex(row => row.spend_input == currentSpend)
  const currentRoas = curve[foundIndex]?.predicted_roas || 0
  // findIndex returns -1 (truthy) when not found, so clamp explicitly instead of `|| 0`
  const currentIndex = Math.max(foundIndex, 0)

  // point whose predicted ROAS is closest to break-even (1.0); looked up once, not per row
  const roasClosestToOne = findClosestPoint(curve, 'predicted_roas', 1)
  const pointClosestToOne = curve.find(row => row.predicted_roas == roasClosestToOne) || {}

  // the curve is sorted by descending spend, so index - 1 is the next higher spend point
  const recommendIncrease = currentRoas > 1
  const directionalIndex = recommendIncrease
    ? Math.max(currentIndex - 1, 0)
    : Math.min(currentIndex + 1, curve.length - 1)
  const directionalChange = curve[directionalIndex].spend_input

  // limit the move to 25% of current spend, or to the adjacent curve point if that is further
  const maxRecommendedChange = recommendIncrease
    ? _.max([currentSpend * 1.25, directionalChange])
    : _.min([currentSpend * 0.75, directionalChange])

  // don't move past the break-even point; no recommendation when nothing is currently spent
  let recommendation
  if (currentSpend == 0) {
    recommendation = 0
  } else if (recommendIncrease) {
    recommendation = _.min([pointClosestToOne.spend_input || Infinity, maxRecommendedChange])
  } else {
    recommendation = _.max([pointClosestToOne.spend_input || -Infinity, maxRecommendedChange])
  }

  const recommendedSpend = findClosestPoint(curve, 'spend_input', recommendation)
  // NOTE(review): intent was "don't recommend going all the way down", but the curve is
  // sorted by DESCENDING spend, so slice(1) skips the highest-spend point, not the lowest.
  // When recommendedSpend is that point, revenue falls back to 0 — confirm intent.
  const recommendedPoint = curve.slice(1).find(x => x.spend_input === recommendedSpend)
  const recommendedRevenue = recommendedPoint?.marginal_response || 0
  // NOTE(review): yields NaN/Infinity when recommendedSpend is 0
  const recommendedRoas = findClosestPoint(curve, 'predicted_roas', recommendedRevenue / recommendedSpend)

  return {
    minRoas,
    maxRoas,
    // lodash v4 `_.min`/`_.max` ignore an iteratee argument; `minBy`/`maxBy` honour it
    minSpend: _.minBy(curve, 'spend_input')?.spend_input,
    maxSpend: _.maxBy(curve, 'spend_input')?.spend_input,
    currentSpend,
    // optimal* currently mirrors recommended* (the max-curvature optimum is disabled)
    optimalSpend: recommendedSpend,
    optimalRevenue: recommendedRevenue,
    optimalRoas: recommendedRoas,
    recommendedSpend,
    recommendedRevenue,
    recommendedRoas,
    historicalSpendMin,
    historicalSpendMax,
  }
}

/**
 * React hook that loads the available MMM model runs and exposes derived data for the
 * selected run.
 *
 * On mount it fetches the run list and selects the first entry. Whenever
 * `selectedModel` changes, it fetches that run's data.
 *
 * @param {*} modelId - currently unused; kept for API compatibility
 * @returns {{
 *   models: Array, selectedModel: *, setSelectedModel: Function, modelData: object|undefined,
 *   features: Array, channels: Array, responseCurves: object|undefined,
 *   featureStats: object|undefined, modelDate: string|undefined
 * }}
 */
export const useModelData = (modelId) => {
  const [models, setModels] = useState([])
  const [selectedModel, setSelectedModel] = useState()
  const [modelData, setModelData] = useState()

  const { trainingStats } = modelData || {}

  // get available models and select the first (most recent) one
  useEffect(() => {
    let cancelled = false
    const initData = async () => {
      try {
        const res = await getMMMRuns()
        if (cancelled) return
        const runs = res?.model_run ?? []
        setModels(runs)
        setSelectedModel(runs[0])
      } catch (err) {
        console.error('Failed to load MMM model runs', err)
      }
    }
    initData()
    return () => { cancelled = true }
  }, [])

  // fetch data for the selected model; ignore responses for a model that is no longer selected
  useEffect(() => {
    if (!selectedModel) return
    let cancelled = false
    getMMMData(selectedModel)
      .then(data => {
        if (!cancelled) setModelData(data)
      })
      .catch(err => console.error('Failed to load MMM model data', err))
    return () => { cancelled = true }
  }, [selectedModel])

  // generate nicely formatted list of features and channels
  const [features, channels] = useMemo(() => {
    if (!modelData) return [[], []]
    const builtFeatures = buildChannelTactics(modelData.features)
    const uniqueChannels = _.uniq(builtFeatures.map(f => f.channel))
    return [builtFeatures, uniqueChannels]
  }, [modelData])

  // calculate response curves for each feature, restricted to its training spend range
  const responseCurves = useMemo(() => {
    if (!modelData || !features?.length) return
    const rawCurves = modelData.raw?.response_curves ?? []
    return features.reduce((acc, f) => {
      const featureKey = f.value
      const spendRange = trainingStats?.find(x => x.feature === featureKey) || {}
      acc[featureKey] = buildResponseCurve(rawCurves, featureKey, spendRange)
      return acc
    }, {})
  }, [modelData, features, trainingStats])

  // calculate key stats for each feature that has a non-empty curve
  const featureStats = useMemo(() => {
    if (!responseCurves || !trainingStats?.length) return
    return Object.entries(responseCurves).reduce((acc, [featureKey, curve]) => {
      if (!curve?.length) return acc
      const featureTrainingStats = trainingStats.find(x => x.feature === featureKey) || {}
      acc[featureKey] = computeFeatureStats(curve, featureTrainingStats)
      return acc
    }, {})
  }, [responseCurves, trainingStats])

  // run date of selected model (selectedModel is treated as a unix timestamp in seconds)
  const modelDate = useMemo(() => {
    if (!selectedModel) return
    return moment.unix(selectedModel).format('MMMM D, YYYY')
  }, [selectedModel])

  return {
    models,
    selectedModel,
    setSelectedModel,
    modelData,
    features,
    channels,
    responseCurves,
    featureStats,
    modelDate,
  }
}
