import { ScaleTime, Selection, EnterElement } from "d3"
import { seriesCanvasLine } from "d3fc"
import { ModalityGraphGroupReactCallbacks } from "../../../../Types/ReactCallbacks"
import { TimeSeriesPageManager } from "../../../../Data/TimeSeriesPageManager"
import { TraceConfig } from "../../../../Types/Trace"
import { ModalityPage } from "../../../../Data/ModalityPage"
import { D3OneToOneRenderable } from "../../../D3/D3OneToOneRenderable"
import { TimeSeriesData, getTimeSeriesDataAccessor } from "../../../../Data/TimeSeriesData"

/**
 * Renders one time-series trace onto a <canvas> embedded in an SVG <foreignObject>.
 *
 * Data is split into pages by a TimeSeriesPageManager. Each page is painted into its own
 * horizontal region of the canvas, optionally via an offscreen canvas whose output is cached
 * on the page as an ImageBitmap. The short line segments that cross page boundaries are drawn
 * separately by fillInGapBetweenPages(). While a snapshot is held (takeSnapshot/clearSnapshot),
 * render() stretches the snapshot to the current x-scale instead of redrawing from data.
 */
export class D3Trace extends D3OneToOneRenderable<SVGGElement, SVGForeignObjectElement, TraceConfig, ModalityGraphGroupReactCallbacks> {
	/** Draws directly onto the visible canvas using config.xScale; bound to the canvas context in enter(). */
	private series = seriesCanvasLine()
	/** Draws one page at a time onto offscreenCanvas; its output is cached per page. */
	private offscreenRenderer = seriesCanvasLine()
	private context?: CanvasRenderingContext2D
	private pageManager: TimeSeriesPageManager<ModalityPage>
	/** Bitmap of the visible canvas plus the x-domain it covered, used by rescale(). */
	private previousCanvasSnapshot: { startDate: Date; endDate: Date; bitmap: ImageBitmap } | null = null
	private offscreenCanvas: OffscreenCanvas
	/** Copy of config.xScale whose domain/range are re-pointed at a single page before each offscreen render. */
	private offscreenXScale: ScaleTime<any, any, any>
	/** Reused return value of getPageRectangle(); it is overwritten on every call. */
	private pageRectangle = { x: 0, y: 0, width: 0, height: 0 }
	private dataAccessor

	constructor(root: SVGGElement, config: TraceConfig, pageManager: TimeSeriesPageManager<ModalityPage>, reactCallbacks: ModalityGraphGroupReactCallbacks) {
		super(root, config, "d3-trace", reactCallbacks)
		this.pageManager = pageManager
		this.dataAccessor = getTimeSeriesDataAccessor(config)

		const width = config.xScale.range()[1]
		const height = config.yScale.range()[0]
		this.offscreenCanvas = new OffscreenCanvas(width, height)

		// The offscreen x-scale is re-pointed at one page at a time in drawPage()
		this.offscreenXScale = config.xScale.copy().range([0, width])
		this.offscreenRenderer = seriesCanvasLine()
			.xScale(this.offscreenXScale)
			.yScale(this.config.yScale)
			.crossValue((p: [number, number]) => p[0])
			.mainValue(this.dataAccessor)
			.context(this.offscreenCanvas.getContext("2d"))
			.decorate((context: CanvasRenderingContext2D) => {
				context.strokeStyle = this.config.color
			})

		this.render()
	}

	/**
	 * Clears and repaints one page's region of the visible canvas, then redraws the segments
	 * that connect adjacent pages (clearing the page region erases part of them).
	 * Does nothing when the page is undefined or the canvas has not been created yet.
	 */
	public renderPage = (page: ModalityPage | undefined) => {
		if (this.drawPage(page)) {
			this.fillInGapBetweenPages()
		}
	}

	/** True while a canvas snapshot is held, i.e. between takeSnapshot() and clearSnapshot(). */
	public isRescaling = () => this.previousCanvasSnapshot !== null

	/**
	 * Redraws the visible canvas by stretching the stored snapshot to where its x-domain falls
	 * under the current x-scale, and drops this trace's cached page bitmaps.
	 * Does nothing without a canvas context or a snapshot.
	 */
	public rescale = () => {
		const snapshot = this.previousCanvasSnapshot
		if (!this.context || !snapshot) {
			return
		}

		this.invalidateRenderCache()

		const { canvas } = this.context
		this.context.clearRect(0, 0, canvas.width, canvas.height)
		const { x, y, width, height } = this.getPageRectangle(snapshot.startDate, snapshot.endDate)
		this.context.drawImage(snapshot.bitmap, x, y, width, height)
		// Gap segments are intentionally not drawn here: fillInGapBetweenPages() is a no-op while a snapshot is held.
	}

	/** Rescales from the snapshot while one is held; otherwise defers to the base class render. */
	public render = () => {
		if (this.previousCanvasSnapshot) {
			this.rescale()
			return
		}

		super.render()
	}

	/**
	 * Repaints every page in view directly onto the visible canvas (bypassing the render cache),
	 * redraws the cross-page segments, and then drops this trace's cached page bitmaps.
	 */
	public redraw = () => {
		const context = this.context

		if (context) {
			this.pageManager.getPagesInView().forEach(page => {
				if (!page) {
					return
				}

				const { x, y, width, height } = this.getPageRectangle(page.startTime, page.endTime)
				context.clearRect(x, y, width, height)

				const data = page.data.get(this.config.dataKey)
				if (data) {
					this.renderDataWithoutPageBoundaries(data, this.series)
				}
			})
			this.fillInGapBetweenPages()
		}

		this.invalidateRenderCache()
	}

	/**
	 * Captures the visible canvas together with the x-domain it currently shows, for use by rescale().
	 * Any previously held snapshot bitmap is released.
	 *
	 * NOTE(review): if clearSnapshot() runs while createImageBitmap is pending, this still stores
	 * the new snapshot afterwards. Confirm callers await this before clearing.
	 */
	public takeSnapshot = async () => {
		if (!this.context) {
			return
		}

		// Read the domain before awaiting: the canvas pixels are copied when createImageBitmap is called,
		// and the scale may change while the promise is pending.
		const [startDate, endDate] = this.config.xScale.domain()
		const bitmap = await createImageBitmap(this.context.canvas)

		// Release the bitmap being replaced instead of waiting for garbage collection
		this.previousCanvasSnapshot?.bitmap.close()
		this.previousCanvasSnapshot = { startDate, endDate, bitmap }
	}

	/** Drops the held snapshot (ending rescaling mode) and releases its bitmap. */
	public clearSnapshot = () => {
		this.previousCanvasSnapshot?.bitmap.close()
		this.previousCanvasSnapshot = null
	}

	/** Key under which this graph/data-key combination stores its bitmap in page.renderCache. */
	private getRenderCacheKey = () => {
		return `${this.config.graphId}-${this.config.dataKey}`
	}

	/**
	 * Pixel rectangle on the visible canvas covering [startTime, endTime] under config.xScale.
	 * Height is taken from yScale.range()[0]; y is always 0.
	 *
	 * Returns a shared object that is overwritten on every call, so destructure it immediately.
	 */
	private getPageRectangle = (startTime: Date | number, endTime: Date | number) => {
		const left = this.config.xScale(startTime)
		const right = this.config.xScale(endTime)

		// Coordinates are deliberately not rounded. Rounding to whole pixels may make blitting
		// cached bitmaps faster, but it would reduce rendering quality.
		const rect = this.pageRectangle
		rect.x = left
		rect.width = right - left
		rect.height = this.config.yScale.range()[0]

		return rect
	}

	/**
	 * Re-syncs both renderers and the offscreen canvas size with the current config,
	 * and drops cached page bitmaps that were rendered under the old state.
	 */
	protected updateDerivedState = () => {
		this.series
			.xScale(this.config.xScale)
			.yScale(this.config.yScale)
			.decorate((context: CanvasRenderingContext2D) => {
				context.strokeStyle = this.config.color
			})

		// Keep the offscreen renderer's y-scale in step with the on-screen series; its x-scale is
		// the private offscreenXScale, which drawPage() re-points per page.
		this.offscreenRenderer
			.yScale(this.config.yScale)
			.decorate((context: CanvasRenderingContext2D) => {
				context.strokeStyle = this.config.color
			})

		this.offscreenCanvas.width = this.config.xScale.range()[1]
		this.offscreenCanvas.height = this.config.yScale.range()[0]

		this.invalidateRenderCache()
	}

	/** Creates the <foreignObject><canvas> pair, sizes it to the plot, and binds the on-screen series to it. */
	protected enter = (newTrace: Selection<EnterElement, any, any, any>): Selection<SVGForeignObjectElement, any, any, any> => {
		const foreignObject = newTrace.append("foreignObject").attr("class", this.className)
		const canvas = foreignObject.append("xhtml:canvas") as Selection<HTMLCanvasElement, any, any, any>

		const { width, height } = this.getCanvasSize()
		foreignObject.attr("width", width).attr("height", height)
		canvas.attr("width", width).attr("height", height)

		const canvasNode = canvas.node()

		if (canvasNode != null) {
			this.context = canvasNode.getContext("2d") ?? undefined

			this.series = seriesCanvasLine()
				.xScale(this.config.xScale)
				.yScale(this.config.yScale)
				.crossValue((p: [number, number]) => p[0])
				.mainValue(this.dataAccessor)
				.context(this.context)
				.decorate((context: CanvasRenderingContext2D) => {
					context.strokeStyle = this.config.color
				})
		}

		return foreignObject
	}

	/** Resizes the existing <foreignObject>/<canvas>, repaints all pages in view, then connects them. */
	protected update = (updatedForeignObject: Selection<any, any, any, any>): Selection<any, any, any, any> => {
		const { width, height } = this.getCanvasSize()
		updatedForeignObject.attr("width", width).attr("height", height)
		updatedForeignObject.select("canvas").attr("width", width).attr("height", height)

		// Paint every page first and connect them once: painting a page clears part of the gap segments.
		this.pageManager.getPagesInView().forEach(page => this.drawPage(page))
		this.fillInGapBetweenPages()

		return updatedForeignObject
	}

	/** Plot width and height derived from the x and y scale ranges. */
	private getCanvasSize() {
		const width = this.config.xScale?.range()[1] - this.config.xScale?.range()[0]
		const height = this.config.yScale?.range()[0] - this.config.yScale?.range()[1]
		return { width, height }
	}

	/**
	 * Clears one page's region of the visible canvas and repaints it in one of three ways:
	 * - from the page's render cache, when the cached bitmap matches the canvas size;
	 * - from the page's data, rendering offscreen and caching the result;
	 * - with a light-grey placeholder, while the page has no data and has not loaded.
	 * The cross-page gap segments are not drawn here.
	 *
	 * @returns false if nothing was touched because the page or the canvas context is missing.
	 */
	private drawPage(page: ModalityPage | undefined): boolean {
		if (!page || !this.context) {
			return false
		}

		const { x, y, width, height } = this.getPageRectangle(page.startTime, page.endTime)
		this.context.clearRect(x, y, width, height)

		const cacheKey = this.getRenderCacheKey()
		const cachedTraceRender = page.renderCache.get(cacheKey)
		const { canvas } = this.context

		// A cached bitmap is only reusable if it was rendered at the current canvas resolution
		if (cachedTraceRender && canvas.width === cachedTraceRender.width && canvas.height === cachedTraceRender.height) {
			this.context.drawImage(cachedTraceRender, x, y)
			return true
		}

		const traceData = page.data.get(this.config.dataKey)

		if (traceData) {
			// Re-point the offscreen scale at this page so the page is drawn starting at x = 0
			this.offscreenXScale.domain([page.startTime, page.endTime]).range([0, width])
			this.renderDataWithoutPageBoundaries(traceData, this.offscreenRenderer)
			const bitmap = this.offscreenCanvas.transferToImageBitmap()
			page.renderCache.set(cacheKey, bitmap)
			this.context.drawImage(bitmap, x, y)
		} else if (!page.loaded) {
			this.context.fillStyle = "lightgray"
			this.context.fillRect(x, y, width, height)
		}

		return true
	}

	/**
	 * Drops this trace's cached page bitmaps so pages are re-rendered from data on the next draw.
	 *
	 * NOTE(review): evicted bitmaps are not close()d because renderCache's value type is declared
	 * elsewhere; consider closing them if nothing else holds a reference.
	 */
	private invalidateRenderCache() {
		const cacheKey = this.getRenderCacheKey()
		this.pageManager.getAllLoadedPages().forEach(page => {
			page?.renderCache.delete(cacheKey)
		})
	}

	/**
	 * Draws the short segments that cross page boundaries, for each of these pairs:
	 * - the two pages in view;
	 * - the page before the view and the first page in view;
	 * - the second page in view and the page after the view.
	 * Skipped while rescaling, without a canvas context, with fewer than two pages in view,
	 * or while any page in view has not loaded.
	 */
	private fillInGapBetweenPages() {
		// this.series has no drawing context until enter() has run
		if (!this.context || this.isRescaling()) {
			return
		}

		const pagesInView = this.pageManager.getPagesInView()

		if (pagesInView.length < 2 || pagesInView.some(page => !page?.loaded)) {
			return
		}

		const pageBeforeViewData = this.pageManager.getPageBeforeView()?.data.get(this.config.dataKey)
		const currentPageData = pagesInView[0]?.data.get(this.config.dataKey)
		const nextPageData = pagesInView[1]?.data.get(this.config.dataKey)
		const pageAfterViewData = this.pageManager.getPageAfterView()?.data.get(this.config.dataKey)

		this.series(this.calculateGapFill(currentPageData, nextPageData))
		this.series(this.calculateGapFill(pageBeforeViewData, currentPageData))
		this.series(this.calculateGapFill(nextPageData, pageAfterViewData))
	}

	/**
	 * Renders traceData with its first and last points temporarily replaced by [undefined, undefined].
	 * This keeps the page's own line away from the page edges; presumably the d3fc series treats
	 * undefined values as gaps. The edge segments are drawn by fillInGapBetweenPages() instead.
	 * traceData is always restored before returning, even if the renderer throws.
	 */
	private renderDataWithoutPageBoundaries(traceData: TimeSeriesData, renderer: (traceData: TimeSeriesData) => void) {
		if (traceData.length === 0) {
			// Nothing to draw; writing to index 0 / length - 1 here would grow the caller's array
			return
		}

		const lastIndex = traceData.length - 1
		const firstPoint = traceData[0]
		const lastPoint = traceData[lastIndex]

		traceData[0] = [undefined, undefined] as any
		traceData[lastIndex] = [undefined, undefined] as any

		try {
			renderer(traceData)
		} finally {
			traceData[0] = firstPoint
			traceData[lastIndex] = lastPoint
		}
	}

	/**
	 * Builds the polyline that bridges two adjacent pages. It runs from the first page's
	 * second-to-last point, through a boundary point, to the second page's second point.
	 * - The boundary point's timestamp is the first page's last timestamp, or the second page's
	 *   first timestamp if the first page has no data.
	 * - Its value is the mean of the two boundary values, or the first page's last value alone.
	 *   The value is undefined when the first page has no data.
	 * Returns an empty array when no boundary timestamp is available.
	 */
	private calculateGapFill(firstPageData: TimeSeriesData | undefined, secondPageData: TimeSeriesData | undefined) : ([number, number] | [undefined, undefined])[] {
		const pageBoundaryPreviousPoint = firstPageData?.at(-1)
		const pageBoundaryNextPoint = secondPageData?.at(0)
		const pageBoundaryTimestamp = pageBoundaryPreviousPoint?.at(0) ?? pageBoundaryNextPoint?.at(0)

		// Explicit check so a legitimate timestamp of 0 is not mistaken for "no boundary"
		if (pageBoundaryTimestamp == null || Number.isNaN(pageBoundaryTimestamp)) {
			return []
		}

		let averageValue

		if (pageBoundaryPreviousPoint) {
			averageValue = this.dataAccessor(pageBoundaryPreviousPoint)
		}

		if (averageValue !== undefined && pageBoundaryNextPoint) {
			averageValue += this.dataAccessor(pageBoundaryNextPoint)
			averageValue /= 2
		}

		return [firstPageData?.at(-2), [pageBoundaryTimestamp, averageValue], secondPageData?.at(1)].filter(p => p !== undefined) as any
	}
}
