import Decimal from "decimal.js";

import { RealizedIndicatorEnum } from "../common";
import { SHORT_TERM_CAPITAL_GAIN_LIQUIDATION_PENALTY } from "../constants";
import { BUSINESS_TIMEZONE, DateOnly } from "../date_utils";
import { OrderPositionType, OrderSide } from "../generated/graphql";
import {
  computeGainsForLots,
  computeTaxLotEntry,
  getMinTaxCompareTo,
  getUnrealizedCandidates,
  LotMatchingInstructionsInputArgs,
  OrderLotInputArgs,
  TaxLotWithId,
} from "../portfolio_utils/taxLotUtils";
import { get, testUUID, ZERO } from "../utils";
import { DirectIndexStockInfoDO } from "./directIndexModel";

/**
 * Per-symbol tax effect of a simulated order, as collected by
 * `applyDirectIndexDelta` from the result of `computeGainDelta`.
 */
export type StockGainInfo = {
  /** Symbol of the stock the deltas below refer to. */
  symbol: string;
  /** Long-term gains of the lots after the order minus those before it. */
  longTermGainDelta: Decimal;
  /** Short-term gains of the lots after the order minus those before it. */
  shortTermGainDelta: Decimal;
  /** True when the summed wash-sale disallowed amount is larger after the order. */
  newWashSale: boolean;
};

// Holdings and each tax lot may have a difference of 0.00001.
// We allow max 6 lots to differ by 0.00001.
// Reference: support.apexclearing.com/hc/requests/395782
// 3FR05398 | ESGD has a 0.00008 difference; move back to 0.00005 once resolved.
// NOTE(review): "max 6 lots" (i.e. 0.00006) does not line up with the 0.00005
// value mentioned above — confirm which baseline is intended.
export const HOLDING_TAX_LOT_MISMATCH_TOLERANCE = new Decimal(0.00008);

/**
 * Approximates the results of applying a direct index delta response to a
 * portfolio. Sells according to LTFO policy and uses the tax lot software to
 * determine wash sales etc.
 *
 * TODO: this currently does not work because StockInfoDO does not have
 * realizedGains for realized lots, which prevents us from computing the wash
 * sale disallowed amount.
 *
 * @param stockInfoInput Per-stock state (symbol, price, tax lots). All tax lots
 *   of one stock must share a single security id.
 * @param quantityDelta Requested share change per symbol: positive buys,
 *   negative sells. Each delta is truncated to 5 decimal places. When a symbol
 *   appears more than once only its first entry is used; symbols that do not
 *   appear in `stockInfoInput` are ignored.
 * @param currentDate Event date used for every simulated order.
 * @param subAccountId Sub account the simulated orders are placed in.
 * @returns `cashDelta` (sum of `-delta * price` over the applied orders),
 *   `stockInfo` (the input entries with `taxLots` replaced where an order was
 *   applied) and `stockGainInfo` (one entry per stock whose gains changed or
 *   that shows a new wash sale).
 * @throws Error when a stock has tax lots with more than one security id, or
 *   when the change in unrealized lot quantity, valued at the stock price,
 *   differs from the requested amount by more than 0.5.
 */
export const applyDirectIndexDelta = (
  stockInfoInput: DirectIndexStockInfoDO[],
  quantityDelta: { delta: Decimal; symbol: string }[],
  currentDate: DateOnly,
  subAccountId: string,
) => {
  // Index the requested deltas once instead of scanning the list for every
  // stock. Only the first entry per symbol is kept so the lookup behaves like
  // a first-match search.
  const deltaBySymbol = new Map<string, { delta: Decimal; symbol: string }>();
  for (const entry of quantityDelta) {
    if (!deltaBySymbol.has(entry.symbol)) {
      deltaBySymbol.set(entry.symbol, entry);
    }
  }

  let cashDelta = ZERO;
  const newStockInfo: DirectIndexStockInfoDO[] = [];
  const stockGainInfo: StockGainInfo[] = [];
  for (const info of stockInfoInput) {
    const securityIds = new Set(info.taxLots.map((t) => t.securityId));
    if (securityIds.size > 1) {
      throw new Error(
        `Cannot apply direct index delta to stock with multiple security ids: ${info.symbol}`,
      );
    }

    // Match the security ids, or make one up
    const securityId = info.taxLots[0]?.securityId ?? info.symbol;
    const qd = deltaBySymbol.get(info.symbol);

    // IMPORTANT TO ROUND TO 5 DECIMAL PLACES
    const delta = qd?.delta.toDP(5, Decimal.ROUND_DOWN);
    if (delta && delta.abs().gt(0)) {
      // Signed cash value of the requested trade (positive for buys).
      const requestedCash = delta.times(info.price);
      const orderInput: OrderInput = {
        subAccountId,
        securityId,
        orderSide: delta.gt(0) ? OrderSide.Buy : OrderSide.Sell,
        orderPositionType: OrderPositionType.Long,
        eventTime: currentDate,
        quantity: delta.abs(),
        sharePrice: info.price ?? ZERO,
      };
      // Buying spends cash, selling releases it.
      cashDelta = cashDelta.sub(requestedCash);
      const { newTaxLots } = computeTaxLotsForOrder(
        orderInput,
        info.taxLots,
        [],
        "LTFO",
      );

      // Sanity check: the unrealized quantity must have moved by (about) the
      // requested delta once valued at the stock price.
      const qBefore = getLotQuantity(info.taxLots);
      const qAfter = getLotQuantity(newTaxLots);
      const qDelta = qAfter.sub(qBefore);
      const actualCashDelta = qDelta.times(info.price);
      if (Decimal.abs(actualCashDelta.minus(requestedCash)).gt(0.5)) {
        throw new Error(
          `Mismatched cash delta for ${
            info.symbol
          }: ${actualCashDelta} vs ${requestedCash}`,
        );
      }
      const { shortTermGainDelta, longTermGainDelta, newWashSale } =
        computeGainDelta(info.taxLots, newTaxLots);

      // Only report stocks whose tax picture actually changed.
      if (
        !shortTermGainDelta.eq(0) ||
        !longTermGainDelta.eq(0) ||
        newWashSale
      ) {
        stockGainInfo.push({
          symbol: info.symbol,
          shortTermGainDelta,
          longTermGainDelta,
          newWashSale,
        });
      }

      newStockInfo.push({
        ...info,
        taxLots: newTaxLots,
      });
    } else {
      // No (or zero after truncation) delta: the stock is passed through as is.
      newStockInfo.push(info);
    }
  }

  return {
    cashDelta,
    stockInfo: newStockInfo,
    stockGainInfo,
  };
};

/**
 * Sums the `quantity` of every lot whose realized indicator is `Unrealized`;
 * all other lots are skipped. Returns `ZERO` for an empty list.
 */
const getLotQuantity = (lot: TaxLotWithId[]) => {
  let unrealizedTotal = ZERO;
  for (const entry of lot) {
    if (entry.realizedIndicator === RealizedIndicatorEnum.Unrealized) {
      unrealizedTotal = unrealizedTotal.add(entry.quantity);
    }
  }
  return unrealizedTotal;
};

/** Description of a single order to be applied to a set of tax lots. */
export type OrderInput = {
  /** Sub account placing the order. */
  subAccountId: string;
  /** Security being traded; `computeTaxLotsForOrder` only considers lots with this id. */
  securityId: string;
  /** Buy or sell. */
  orderSide: OrderSide;
  orderPositionType: OrderPositionType;
  /** Date of the order; `computeTaxLotsForOrder` turns it into 12:00 in BUSINESS_TIMEZONE. */
  eventTime: DateOnly;
  /** Number of shares; `computeTaxLotsForOrder` rounds it down to 5 decimal places. */
  quantity: Decimal;
  /** Price per share. */
  sharePrice: Decimal;
  /**
   * Optional. For sell orders, `computeTaxLotsForOrder` sells from the
   * unrealized lot whose `openLotId` equals this value instead of picking lots
   * by policy.
   */
  lotMatchingId?: string;
};

/**
 * Builds the lot matching instructions used for a sell under the LTFO policy.
 *
 * Candidate lots are chosen by `getUnrealizedCandidates`, ordered with the
 * comparator produced by `getMinTaxCompareTo` from the order's share price,
 * event time and SHORT_TERM_CAPITAL_GAIN_LIQUIDATION_PENALTY.
 *
 * @param orderInput Order to satisfy; its sub account, security, quantity,
 *   share price and event time are read.
 * @param newTaxLots Lots the candidates are selected from.
 * @returns One instruction per selected lot, carrying the lot's open buy date,
 *   the quantity to sell from it and the lot id.
 */
export const getLTFOLotMatchingInstructions = (
  orderInput: OrderInput,
  newTaxLots: TaxLotWithId[],
) => {
  // Lot id -> lot, so each candidate can be resolved back to its full lot.
  const lotsById = new Map(newTaxLots.map((entry) => [entry.id, entry]));

  const sellRequest = {
    subAccountId: orderInput.subAccountId,
    securityId: orderInput.securityId,
    quantity: orderInput.quantity,
  };
  const lotComparator = getMinTaxCompareTo(
    orderInput.sharePrice,
    orderInput.eventTime,
    SHORT_TERM_CAPITAL_GAIN_LIQUIDATION_PENALTY,
  );
  const candidates = getUnrealizedCandidates(
    sellRequest,
    newTaxLots,
    lotComparator,
  );

  return candidates.map(({ taxLotEntryId, sellQuantity }) => {
    // NOTE(review): `get` presumably rejects missing values — confirm in ../utils.
    const matchedLot = get(lotsById.get(taxLotEntryId));
    return {
      tradeDate: get(matchedLot.taxLotOpenBuyDate),
      quantity: sellQuantity,
      taxLotEntryId,
    };
  });
};

/**
 * Applies a single order to the tax lots of one security and returns the
 * resulting lot sets.
 *
 * The order quantity is rounded down to 5 decimal places on an internal copy;
 * the caller's `orderInput` object is left untouched.
 *
 * For sell orders the lots to sell from are chosen as follows:
 *  - with `orderInput.lotMatchingId`: the unrealized lot of the same sub
 *    account whose `openLotId` equals that id, provided it holds at least the
 *    order quantity;
 *  - otherwise, with `sellOrder === "LTFO"`: the lots returned by
 *    `getLTFOLotMatchingInstructions` for the sub account's lots;
 *  - otherwise no explicit instructions are passed to `computeTaxLotEntry`.
 *
 * @param orderInput The order to apply.
 * @param existingTaxLots Current lots; only those with the order's security id
 *   are considered (lots of every sub account are kept).
 * @param existingDeletedTaxLots Previously deleted lots; those with the order's
 *   security id are carried over into `newDeletedTaxLots`.
 * @param sellOrder Lot selection policy for sells without a `lotMatchingId`.
 * @returns `newTaxLots` (kept lots followed by newly created ones, each new
 *   lot given a fresh id), `newDeletedTaxLots` and the errors reported by
 *   `computeTaxLotEntry` as strings.
 * @throws Error when the requested `lotMatchingId` cannot be honoured, or when
 *   the unrealized quantity after the order deviates from the expected
 *   quantity by more than HOLDING_TAX_LOT_MISMATCH_TOLERANCE.
 */
export const computeTaxLotsForOrder = (
  orderInput: OrderInput,
  existingTaxLots: TaxLotWithId[],
  existingDeletedTaxLots: TaxLotWithId[],
  sellOrder: "FIFO" | "LTFO",
): {
  newTaxLots: TaxLotWithId[];
  newDeletedTaxLots: TaxLotWithId[];
  taxLotValidationErrors: string[];
} => {
  // Restrict to the order's security. Lots of every subAccount are kept so
  // that lots from other subAccounts are passed to computeTaxLotEntry too.
  const filteredExistingTaxLots = existingTaxLots.filter(
    (t) => t.securityId === orderInput.securityId,
  );
  const filteredExistingDeletedTaxLots = existingDeletedTaxLots.filter(
    (t) => t.securityId === orderInput.securityId,
  );
  const errors: string[] = [];

  // IMPORTANT TO ROUND TO 5 DECIMAL PLACES
  // The rounding is applied to a copy so the caller's order is not mutated.
  const order: OrderInput = {
    ...orderInput,
    quantity: orderInput.quantity.toDP(5, Decimal.ROUND_DOWN),
  };

  const lotMatchingInstructions: LotMatchingInstructionsInputArgs[] = [];
  if (order.orderSide === OrderSide.Sell) {
    if (order.lotMatchingId) {
      const lot = filteredExistingTaxLots.find(
        (t) =>
          t.openLotId === order.lotMatchingId &&
          t.subAccountId === order.subAccountId &&
          t.realizedIndicator === RealizedIndicatorEnum.Unrealized,
      );

      if (lot && lot.quantity.gte(order.quantity)) {
        lotMatchingInstructions.push({
          quantity: order.quantity,
          taxLotEntryId: lot.id,
          tradeDate: DateOnly.fromDateTz(lot.eventTime, BUSINESS_TIMEZONE),
          price: lot.openBuyPrice,
        });
      } else {
        // Same message prefix in both cases; the undersized-lot case adds the
        // quantities so it can be told apart from a missing lot.
        errors.push(
          lot
            ? `Cannot find lot id matching order id ${order.lotMatchingId}: lot quantity ${lot.quantity} is less than order quantity ${order.quantity}`
            : `Cannot find lot id matching order id ${order.lotMatchingId}`,
        );
      }
    } else if (sellOrder === "LTFO") {
      lotMatchingInstructions.push(
        ...getLTFOLotMatchingInstructions(
          order,
          filteredExistingTaxLots.filter(
            (tl) => tl.subAccountId === order.subAccountId,
          ),
        ),
      );
    }
  }

  if (errors.length > 0) {
    // Adding this to improve chances of catching silent fails. Raised before
    // the lot computation, whose result would be discarded anyway.
    throw new Error(errors.join("\n"));
  }

  const args: OrderLotInputArgs = {
    subAccountId: order.subAccountId,
    securityId: order.securityId, // must match tax lots
    side: order.orderSide ?? OrderSide.Buy,
    positionType: order.orderPositionType ?? OrderPositionType.Long,
    quantity: order.quantity,
    sharePrice: order.sharePrice,
    notional: order.quantity.times(order.sharePrice),
    eventTime: order.eventTime.toDateSpecificTime(BUSINESS_TIMEZONE, 12, 0),
    lotMatchingInstructions,
  };

  const quantityBefore = getLotQuantity(filteredExistingTaxLots);
  const expectedQuantityAfter =
    order.orderSide === OrderSide.Sell
      ? quantityBefore.sub(order.quantity)
      : quantityBefore.add(order.quantity);

  const {
    createLots: createLotsRaw,
    deleteLotIds,
    errors: lotErrors,
  } = computeTaxLotEntry(
    args,
    // To ensure lots from other subAccount lots are passed too. A fresh array
    // is handed over so the callee cannot reorder the list used below.
    [...filteredExistingTaxLots],
  );

  // Set lookup instead of scanning the id list for every lot.
  const deletedIds = new Set(deleteLotIds);
  const deletedLots = filteredExistingTaxLots.filter((t) =>
    deletedIds.has(t.id),
  );
  const lotsToKeep = filteredExistingTaxLots.filter(
    (t) => !deletedIds.has(t.id),
  );
  const createLots: TaxLotWithId[] = createLotsRaw.map((t) => ({
    ...t,
    id: testUUID(),
  }));

  const newTaxLots = [...lotsToKeep, ...createLots];
  const newDeletedTaxLots = [...filteredExistingDeletedTaxLots, ...deletedLots];

  const quantityAfter = getLotQuantity(newTaxLots);
  // This may not match due to apex rounding issues, we allow a bit of tolerance
  if (
    quantityAfter
      .sub(expectedQuantityAfter)
      .abs()
      .gt(HOLDING_TAX_LOT_MISMATCH_TOLERANCE)
  ) {
    throw new Error(
      `Quantity mismatch after applying order (securityId ${order.securityId}), expected ${expectedQuantityAfter} but got ${quantityAfter}`,
    );
  }

  return {
    newTaxLots,
    newDeletedTaxLots,
    // `errors` is necessarily empty here (a non-empty list throws above), so
    // only the errors reported by computeTaxLotEntry remain.
    taxLotValidationErrors: Array.from(lotErrors ?? []).map(String),
  };
};

/**
 * Given a list of order inputs and existing tax lots, compute the new tax lots,
 * this method is meant to be used for a single security ID with multiple orders.
 *
 * Orders are applied one after another in ascending `eventTime` order, each
 * one starting from the lots produced by the previous one. The caller's
 * `orderInputLot` array is not reordered.
 *
 * @param orderInputLot Orders to apply; all must target the same security.
 * @param existingTaxLots Lots before any order is applied.
 * @param existingDeletedTaxLots Deleted lots before any order is applied.
 * @param sellOrder Lot selection policy forwarded to `computeTaxLotsForOrder`.
 * @returns The lots and deleted lots after the last order, plus the
 *   validation errors of all orders concatenated in application order.
 * @throws Error when the orders and lots together reference more than one
 *   security id; errors thrown by `computeTaxLotsForOrder` propagate.
 */
export const computeTaxLotsMultipleOrdersSameSecurity = (
  orderInputLot: OrderInput[],
  existingTaxLots: TaxLotWithId[],
  existingDeletedTaxLots: TaxLotWithId[],
  sellOrder: "FIFO" | "LTFO",
): {
  newTaxLots: TaxLotWithId[];
  newDeletedTaxLots: TaxLotWithId[];
  taxLotValidationErrors: string[];
} => {
  const securityIds = new Set([
    ...orderInputLot.map((o) => o.securityId),
    ...existingTaxLots.map((t) => t.securityId),
    ...existingDeletedTaxLots.map((t) => t.securityId),
  ]);

  if (securityIds.size > 1) {
    throw new Error(
      `Cannot compute tax lots for multiple security ids: ${Array.from(
        securityIds,
      )}`,
    );
  }

  let newTaxLots: TaxLotWithId[] = [...existingTaxLots];
  let newDeletedTaxLots: TaxLotWithId[] = [...existingDeletedTaxLots];
  let errors: string[] = [];

  // Array.prototype.sort works in place, so sort a copy to avoid reordering
  // the array owned by the caller.
  const ordersByEventTime = [...orderInputLot].sort((a, b) =>
    a.eventTime.compare(b.eventTime),
  );

  for (const orderInput of ordersByEventTime) {
    const {
      newTaxLots: newTaxLotsRaw,
      newDeletedTaxLots: newDeletedTaxLotsRaw,
      taxLotValidationErrors,
    } = computeTaxLotsForOrder(
      orderInput,
      newTaxLots,
      newDeletedTaxLots,
      sellOrder,
    );

    // Feed the result of this order into the next one.
    newTaxLots = newTaxLotsRaw;
    newDeletedTaxLots = newDeletedTaxLotsRaw;
    errors = errors.concat(taxLotValidationErrors);
  }

  return {
    newTaxLots,
    newDeletedTaxLots,
    taxLotValidationErrors: errors,
  };
};

/**
 * Compares two sets of tax lots and reports how the gains moved between them.
 *
 * @param oldTaxLots Lots before the change.
 * @param newTaxLots Lots after the change.
 * @returns `shortTermGainDelta` / `longTermGainDelta` (gains of the new lots
 *   minus gains of the old lots, as computed by `computeGainsForLots`) and
 *   `newWashSale`, true when the summed `washSalesDisallowed` of the new lots
 *   is strictly greater than that of the old lots.
 */
export const computeGainDelta = (
  oldTaxLots: TaxLotWithId[],
  newTaxLots: TaxLotWithId[],
) => {
  // Total wash-sale disallowed amount across a list of lots.
  const totalWashSalesDisallowed = (lots: TaxLotWithId[]) => {
    let total = ZERO;
    for (const entry of lots) {
      total = total.add(entry.washSalesDisallowed);
    }
    return total;
  };

  const gainsBefore = computeGainsForLots(oldTaxLots);
  const gainsAfter = computeGainsForLots(newTaxLots);
  const washBefore = totalWashSalesDisallowed(oldTaxLots);
  const washAfter = totalWashSalesDisallowed(newTaxLots);

  return {
    shortTermGainDelta: gainsAfter.shortTermGains.minus(
      gainsBefore.shortTermGains,
    ),
    longTermGainDelta: gainsAfter.longTermGains.minus(
      gainsBefore.longTermGains,
    ),
    newWashSale: washAfter.gt(washBefore),
  };
};

/**
 * Computes the resulting gains after offsetting them with losses.
 *
 * Order of application:
 *  1. Short term losses reduce short term gains, and long term losses reduce
 *     long term gains. A term whose gains are not positive is left as is and
 *     keeps its full losses for the next step.
 *  2. Losses left over from the long term reduce any remaining positive short
 *     term gains (never below zero); losses left over from the short term do
 *     the same for remaining positive long term gains.
 *  3. If one term ends up negative while the other is positive, the two are
 *     netted: the positive term absorbs the negative one, and whatever cannot
 *     be absorbed stays on the negative term.
 *
 * @returns The remaining `shortTermGains` and `longTermGains`.
 */
export const offsetGainsWithLosses = (
  stGains: Decimal,
  stLosses: Decimal,
  ltGains: Decimal,
  ltLosses: Decimal,
) => {
  // Nets positive gains against the losses of the same term. Gains that are
  // already zero or negative are returned untouched, with all losses left over.
  const netSameTerm = (gains: Decimal, losses: Decimal) => {
    if (!gains.gt(0)) {
      return { gains, leftoverLosses: losses };
    }
    const net = gains.minus(losses);
    return net.gte(0)
      ? { gains: net, leftoverLosses: ZERO }
      : { gains: ZERO, leftoverLosses: net.abs() };
  };

  const shortTerm = netSameTerm(stGains, stLosses);
  const longTerm = netSameTerm(ltGains, ltLosses);

  let shortResult = shortTerm.gains;
  let longResult = longTerm.gains;

  // Cross-term: leftover long term losses eat into short term gains...
  if (shortResult.gt(0)) {
    shortResult = Decimal.max(0, shortResult.minus(longTerm.leftoverLosses));
  }
  // ...and leftover short term losses eat into long term gains.
  if (longResult.gt(0)) {
    longResult = Decimal.max(0, longResult.minus(shortTerm.leftoverLosses));
  }

  // Final netting when the two terms have opposite signs.
  if (longResult.lt(0) && shortResult.gt(0)) {
    const combined = shortResult.plus(longResult);
    shortResult = Decimal.max(0, combined);
    longResult = combined.minus(shortResult);
  }

  if (shortResult.lt(0) && longResult.gt(0)) {
    const combined = shortResult.plus(longResult);
    longResult = Decimal.max(0, combined);
    shortResult = combined.minus(longResult);
  }

  return {
    shortTermGains: shortResult,
    longTermGains: longResult,
  };
};

/**
 * Offsets gains with losses from optimizer first, then apply loss harvest potential.
 * The gains should include wash sales.
 *
 * Steps:
 *  1. If short and long term gains have opposite signs they are netted: the
 *     positive side absorbs the negative one, any unabsorbed loss stays on the
 *     negative side.
 *  2. `lossHarvestPotential` is subtracted from the short term gains.
 *  3. If long term gains were not positive after step 1 the result is returned
 *     as is. Otherwise a negative short term result is netted against the long
 *     term gains once more.
 *
 * @returns The resulting `shortTermGains` and `longTermGains`.
 */
export const offsetEstimatedTaxImpact = (
  stGains: Decimal,
  ltGains: Decimal,
  lossHarvestPotential: Decimal,
) => {
  // Lets `gain` absorb `loss`: the gain side never drops below zero and any
  // remainder of the loss is handed back.
  const absorb = (loss: Decimal, gain: Decimal) => {
    const combined = loss.plus(gain);
    const gainLeft = Decimal.max(0, combined);
    return { gainLeft, lossLeft: combined.minus(gainLeft) };
  };

  let shortTerm = stGains;
  let longTerm = ltGains;

  // Offset gains with losses first
  if (shortTerm.lt(0) && longTerm.gt(0)) {
    const netted = absorb(shortTerm, longTerm);
    longTerm = netted.gainLeft;
    shortTerm = netted.lossLeft;
  }
  if (longTerm.lt(0) && shortTerm.gt(0)) {
    const netted = absorb(longTerm, shortTerm);
    shortTerm = netted.gainLeft;
    longTerm = netted.lossLeft;
  }

  // Apply loss harvest potential
  shortTerm = shortTerm.minus(lossHarvestPotential);

  // Positive long term gains can still absorb a short term deficit; when long
  // term gains are zero or negative nothing further is netted.
  if (!longTerm.lte(0) && shortTerm.lt(0)) {
    const netted = absorb(shortTerm, longTerm);
    longTerm = netted.gainLeft;
    shortTerm = netted.lossLeft;
  }

  return {
    shortTermGains: shortTerm,
    longTermGains: longTerm,
  };
};
