
import { DEFAULT_TAX_RATES, offsetGainsWithLosses, ZERO } from "@frec-js/common";
import Decimal from "decimal.js";
import { useMemo } from "react";

import {
  EstimateDirectIndexGainLossInput,
  EstimateDirectIndexGainLossQuery,
  useEstimateDirectIndexGainLossQuery,
  UserTaxRates,
} from "../generated/graphql";

/**
 * Combines the user's federal and state rates into one rate per holding term.
 *
 * The rates are summed as-is; `computeTaxes` later divides them by 100, so
 * they are treated as percentages.
 *
 * @param userTaxRates - the user's federal income, federal capital gains and
 *   state capital gains tax rates
 * @returns `shortTermGainsRate` and `longTermGainsRate`
 */
const computeTaxRates = ({
  federalIncomeTaxRate,
  federalCapitalGainsTaxRate,
  stateCapitalGainsTaxRate,
}: UserTaxRates) => ({
  // Short term: federal income tax rate + state capital gains tax rate
  shortTermGainsRate: federalIncomeTaxRate.plus(stateCapitalGainsTaxRate),
  // Long term: federal capital gains tax rate + state capital gains tax rate
  longTermGainsRate: federalCapitalGainsTaxRate.plus(stateCapitalGainsTaxRate),
});

/**
 * Applies percentage rates to the net gains of each holding term.
 *
 * @param shortTermGainsRate - short-term rate, in percent
 * @param longTermGainsRate - long-term rate, in percent
 * @param netShortTermGains - short-term gains to tax
 * @param netLongTermGains - long-term gains to tax
 * @returns `shortTerm` and `longTerm` taxes, plus `totalTaxes` (their sum)
 */
const computeTaxes = (
  shortTermGainsRate: Decimal,
  longTermGainsRate: Decimal,
  netShortTermGains: Decimal,
  netLongTermGains: Decimal
) => {
  // Rates arrive as percentages; turn each into a fraction before applying it.
  const shortTermFraction = shortTermGainsRate.div(100);
  const shortTerm = netShortTermGains.times(shortTermFraction);

  const longTermFraction = longTermGainsRate.div(100);
  const longTerm = netLongTermGains.times(longTermFraction);

  return { shortTerm, longTerm, totalTaxes: shortTerm.plus(longTerm) };
};

/**
 * Tax-impact breakdown for a single holding term (short- or long-term), as
 * assembled by `_computeEstimatedTaxImpact`.
 */
export type TermEstimatedTaxImpact = {
  /**
   * Combined rate for this term, in percent (divided by 100 when applied).
   * Short term: federal income + state capital gains rate.
   * Long term: federal capital gains + state capital gains rate.
   */
  taxRate: Decimal;
  /** The term's `...GainsWithoutWashSales` total plus its wash sales. */
  gains: Decimal;
  /** Gains remaining after `offsetGainsWithLosses` applies the available losses. */
  netGains: Decimal;
  /** Estimated taxes: `netGains` × `taxRate` / 100. */
  taxes: Decimal;
  /** Wash sale total reported by the query for this term. */
  washSales: Decimal;
  /**
   * Displayed net losses. Short term: the applied loss harvest potential
   * (ZERO when LHP is disabled) minus wash sales. Long term: the negated
   * wash sales. A zero result is normalized to ZERO so it is never -0.
   */
  netLosses: Decimal;
  /**
   * Loss harvest potential. Only the short-term entry carries the query's
   * value (even when LHP is disabled); the long-term entry is always ZERO.
   */
  lossHarvestPotential: Decimal;
};

/** Estimated tax impact split by holding term, plus the combined total. */
export type ComputedEstimatedTaxImpact = {
  /** Short-term breakdown; carries the loss harvest potential. */
  shortTerm: TermEstimatedTaxImpact;
  /** Long-term breakdown; its `lossHarvestPotential` is always ZERO. */
  longTerm: TermEstimatedTaxImpact;
  /** Sum of the short- and long-term taxes. */
  totalTaxes: Decimal;
};

/**
 * Builds the per-term tax impact for one scenario (LHP applied or not).
 *
 * Each term's gains are its "without wash sales" total plus its wash sales.
 * `offsetGainsWithLosses` then reduces them by the short-term loss harvest
 * potential when `lossHarvestPotentialEnabled` is true, or by ZERO otherwise.
 * The query's raw `lossHarvestPotential` is always reported on the short-term
 * entry, whatever the flag, so callers can tell whether any LHP exists.
 *
 * @param estimateDirectIndexGainLoss - gain/loss estimate from the query
 * @param userTaxRates - the user's federal and state tax rates
 * @param lossHarvestPotentialEnabled - whether LHP is used to offset gains
 * @returns short- and long-term breakdowns plus total estimated taxes
 */
const _computeEstimatedTaxImpact = (
  estimateDirectIndexGainLoss: EstimateDirectIndexGainLossQuery["estimateDirectIndexGainLoss"],
  userTaxRates: UserTaxRates,
  lossHarvestPotentialEnabled: boolean
): ComputedEstimatedTaxImpact => {
  const {
    totalShortTermGainsWithoutWashSales: shortTermBaseGains,
    totalLongTermGainsWithoutWashSales: longTermBaseGains,
    totalShortTermWashSales: shortTermWashSales,
    totalLongTermWashSales: longTermWashSales,
    lossHarvestPotential,
  } = estimateDirectIndexGainLoss;

  // LHP only counts toward offsets in the "enabled" scenario.
  const appliedLosses = lossHarvestPotentialEnabled
    ? lossHarvestPotential
    : ZERO;

  const rates = computeTaxRates(userTaxRates);

  // Displayed net losses: applied LHP minus wash sales (short term), and the
  // negated wash sales (long term). Decimal keeps signed zeros, so any zero
  // result is swapped for the canonical ZERO to avoid showing -0.
  const shortTermLossDelta = appliedLosses.minus(shortTermWashSales);
  const longTermLossDelta = longTermWashSales.negated();
  const shortTermNetLosses = shortTermLossDelta.isZero()
    ? ZERO
    : shortTermLossDelta;
  const longTermNetLosses = longTermLossDelta.isZero()
    ? ZERO
    : longTermLossDelta;

  // Gains for each term: the "without wash sales" total plus the wash sales.
  const shortTermGrossGains = shortTermBaseGains.plus(shortTermWashSales);
  const longTermGrossGains = longTermBaseGains.plus(longTermWashSales);

  const offset = offsetGainsWithLosses(
    shortTermGrossGains,
    appliedLosses,
    longTermGrossGains,
    ZERO,
  );

  const taxes = computeTaxes(
    rates.shortTermGainsRate,
    rates.longTermGainsRate,
    offset.shortTermGains,
    offset.longTermGains
  );

  return {
    shortTerm: {
      taxRate: rates.shortTermGainsRate,
      gains: shortTermGrossGains,
      netGains: offset.shortTermGains,
      taxes: taxes.shortTerm,
      washSales: shortTermWashSales,
      netLosses: shortTermNetLosses,
      lossHarvestPotential, // short term only
    },
    longTerm: {
      taxRate: rates.longTermGainsRate,
      gains: longTermGrossGains,
      netGains: offset.longTermGains,
      taxes: taxes.longTerm,
      washSales: longTermWashSales,
      netLosses: longTermNetLosses,
      lossHarvestPotential: ZERO,
    },
    totalTaxes: taxes.totalTaxes,
  };
};

/**
 * Computes the tax impact for both scenarios from a single query result:
 * with loss harvest potential applied (`withLHP`) and without it
 * (`withoutLHP`).
 *
 * @param query - query result; only `estimateDirectIndexGainLoss` is read
 * @param userTaxRates - the user's federal and state tax rates
 * @returns both scenario results
 */
const computeEstimatedTaxImpact = (
  { estimateDirectIndexGainLoss }: EstimateDirectIndexGainLossQuery,
  userTaxRates: UserTaxRates
): {
  withLHP: ComputedEstimatedTaxImpact;
  withoutLHP: ComputedEstimatedTaxImpact;
} => {
  const scenario = (lhpEnabled: boolean) =>
    _computeEstimatedTaxImpact(estimateDirectIndexGainLoss, userTaxRates, lhpEnabled);

  // Object literal properties evaluate in order: LHP scenario first.
  return {
    withLHP: scenario(true),
    withoutLHP: scenario(false),
  };
};

/** All-zero breakdown for one holding term, used before an estimate exists. */
const INITIAL_TERM_ESTIMATED_TAX_IMPACT: TermEstimatedTaxImpact = {
  taxRate: ZERO,
  gains: ZERO,
  netGains: ZERO,
  taxes: ZERO,
  washSales: ZERO,
  netLosses: ZERO,
  lossHarvestPotential: ZERO,
};

/**
 * All-zero estimate that `useEstimatedTaxImpact` returns while the query is
 * skipped, loading, or has no data yet.
 *
 * Note: `shortTerm` and `longTerm` point to the same object, and this value
 * is shared by every hook instance, so consumers must not mutate it.
 */
const INITIAL_ESTIMATED_TAX_IMPACT: ComputedEstimatedTaxImpact = {
  shortTerm: INITIAL_TERM_ESTIMATED_TAX_IMPACT,
  longTerm: INITIAL_TERM_ESTIMATED_TAX_IMPACT,
  totalTaxes: ZERO,
};

/**
 * React hook that fetches a direct-index gain/loss estimate and derives the
 * estimated tax impact from it.
 *
 * @param input - variables forwarded to the gain/loss estimate query
 * @param lossHarvestPotentialEnabled - picks the scenario with LHP applied
 *   (`true`) or without it (`false`)
 * @param userTaxRates - the user's tax rates; `DEFAULT_TAX_RATES` when omitted
 * @param skip - skips the query; the result then has `noTaxImpact: true`
 * @returns the query's `loading`, `error` and `refetch`, plus:
 *   - `estimatedTaxImpact`: the selected scenario, or the all-zero estimate
 *     while skipped, loading, or before data arrives;
 *   - `showLossHarvestPotential`: true when the short-term LHP is non-zero;
 *   - `noTaxImpact`: true when skipped, or when total taxes without LHP are zero.
 */
export const useEstimatedTaxImpact = (
  input: EstimateDirectIndexGainLossInput,
  lossHarvestPotentialEnabled: boolean,
  userTaxRates?: UserTaxRates,
  skip?: boolean
) => {
  const { data, loading, error, refetch } = useEstimateDirectIndexGainLossQuery({
    variables: { input },
    notifyOnNetworkStatusChange: true,
    skip,
  });

  return useMemo(() => {
    // Baseline result: query state plus an all-zero estimate.
    const base = {
      loading,
      error,
      refetch,
      estimatedTaxImpact: INITIAL_ESTIMATED_TAX_IMPACT,
      showLossHarvestPotential: false,
      noTaxImpact: false,
    };

    // Skipped query: nothing to estimate.
    if (skip) return { ...base, noTaxImpact: true };

    // Still fetching, or nothing to compute from yet.
    if (loading || !data) return base;

    const taxRates = userTaxRates ?? (DEFAULT_TAX_RATES as UserTaxRates);
    const { withLHP, withoutLHP } = computeEstimatedTaxImpact(data, taxRates);
    const estimatedTaxImpact = lossHarvestPotentialEnabled ? withLHP : withoutLHP;
    const hasLossHarvestPotential =
      !estimatedTaxImpact.shortTerm.lossHarvestPotential.isZero();

    return {
      ...base,
      estimatedTaxImpact,
      // Hide LHP entirely when there is none.
      showLossHarvestPotential: hasLossHarvestPotential,
      // No taxes even without LHP: nothing worth detailing.
      noTaxImpact: withoutLHP.totalTaxes.isZero(),
    };
  }, [
    loading,
    data,
    userTaxRates,
    error,
    refetch,
    lossHarvestPotentialEnabled,
    skip,
  ]);
};
