import vtkDataArray from '@kitware/vtk.js/Common/Core/DataArray';
import vtkPoints from '@kitware/vtk.js/Common/Core/Points';
import vtkPolyData from '@kitware/vtk.js/Common/DataModel/PolyData';
import {mat4, vec3} from 'gl-matrix';

import {
	type ActorEntry,
	cornerstone,
	type PointCloudColorScheme,
} from '@/library';
import {
	bones,
	type Bone,
	type Coordinate,
	type DigitalTwinPolygonType,
	type ReamDigitalTwinObjectActorEntry,
	type ResectionPlaneKey,
	type Rotation,
	type Vector3,
	type VtkStateRef,
	type VtkState,
} from '@/types';
import {calculateGlobalRange, makeLut} from '@/utils';

/**
 * Resolves the live point cloud actor entry for `bone` together with its
 * backup entry.
 *
 * All four bones' live and backup entries are looked up through
 * `getPointCloudAndBackupActors`, so this throws if any of them is missing,
 * not only the ones belonging to the requested bone.
 *
 * @param bone - Bone whose point cloud entries should be returned.
 * @param vtkState - Ref whose `current` value holds the point cloud actor lists.
 * @returns The live entry (`actorEntry`) and its backup (`backupActorEntry`).
 */
const getPointCloudActorEntryAndBackup = ({
	bone,
	vtkState,
}: {
	bone: Bone;
	vtkState: VtkStateRef;
}): {actorEntry: ActorEntry; backupActorEntry: ActorEntry} => {
	const found = getPointCloudAndBackupActors(vtkState);

	// Live entries keyed by bone.
	const liveByBone: Record<Bone, ActorEntry> = {
		femur: found.femoralActor,
		fibula: found.fibularActor,
		patella: found.patellarActor,
		tibia: found.tibialActor,
	};

	// Backup entries keyed by bone.
	const backupByBone: Record<Bone, ActorEntry> = {
		femur: found.femoralActorBackup,
		fibula: found.fibularActorBackup,
		patella: found.patellarActorBackup,
		tibia: found.tibialActorBackup,
	};

	const actorEntry = liveByBone[bone];
	const backupActorEntry = backupByBone[bone];

	return {actorEntry, backupActorEntry};
};

/**
 * Looks up the live and backup point cloud actor entries of all four bones.
 *
 * Entries are matched by id (`femur_pointcloud`, `tibia_pointcloud`,
 * `patella_pointcloud`, `fibula_pointcloud`) in
 * `vtkState.current.pointCloudActors` and
 * `vtkState.current.pointCloudActorsBackup`.
 *
 * @param vtkState - Ref whose `current` value holds both actor lists.
 * @returns The eight entries, keyed by bone adjective with a `Backup` suffix
 * for the backup copies.
 * @throws Error naming the first missing entry, checked in the order femoral,
 * tibial, patellar, fibular (live before backup for each bone).
 */
const getPointCloudAndBackupActors = (vtkState: VtkStateRef) => {
	const findLive = (id: string) =>
		vtkState.current.pointCloudActors.find((entry) => entry.id === id);
	const findBackup = (id: string) =>
		vtkState.current.pointCloudActorsBackup.find((entry) => entry.id === id);

	// Returns the entry unchanged, or throws with the given message when absent.
	function requireEntry<T>(entry: T | undefined, message: string): T {
		if (!entry) {
			throw new Error(message);
		}

		return entry;
	}

	// All lookups happen before any validation.
	const femoralLive = findLive('femur_pointcloud');
	const femoralBackup = findBackup('femur_pointcloud');
	const tibialLive = findLive('tibia_pointcloud');
	const tibialBackup = findBackup('tibia_pointcloud');
	const patellarLive = findLive('patella_pointcloud');
	const patellarBackup = findBackup('patella_pointcloud');
	const fibularLive = findLive('fibula_pointcloud');
	const fibularBackup = findBackup('fibula_pointcloud');

	// Validation order determines which error a caller sees first.
	const femoralActor = requireEntry(femoralLive, 'Femoral actor not found');
	const femoralActorBackup = requireEntry(
		femoralBackup,
		'Femoral actor backup not found',
	);
	const tibialActor = requireEntry(tibialLive, 'Tibial actor not found');
	const tibialActorBackup = requireEntry(
		tibialBackup,
		'Tibial actor backup not found',
	);
	const patellarActor = requireEntry(patellarLive, 'Patellar actor not found');
	const patellarActorBackup = requireEntry(
		patellarBackup,
		'Patellar actor backup not found',
	);
	const fibularActor = requireEntry(fibularLive, 'Fibular actor not found');
	const fibularActorBackup = requireEntry(
		fibularBackup,
		'Fibular actor backup not found',
	);

	return {
		femoralActor,
		femoralActorBackup,
		fibularActor,
		fibularActorBackup,
		patellarActor,
		patellarActorBackup,
		tibialActor,
		tibialActorBackup,
	};
};

/**
 * Replaces the live actor's mapper input with a fresh poly data object that is
 * a shallow copy of the backup actor's mapper input.
 *
 * @param actorEntry - Entry whose actor receives the restored poly data.
 * @param backupActorEntry - Entry whose actor supplies the poly data to copy.
 * @throws Error when either entry lacks an actor, a mapper, or input data.
 */
function restoreActorFromBackup({
	actorEntry,
	backupActorEntry,
}: {
	actorEntry: ActorEntry;
	backupActorEntry: ActorEntry;
}) {
	// Both sides must be fully wired (actor -> mapper -> input data) up front.
	if (!actorEntry.actor?.getMapper()?.getInputData()) {
		throw new Error(`Missing properties on actorEntry`);
	}

	if (!backupActorEntry.actor?.getMapper()?.getInputData()) {
		throw new Error(`Missing properties on backupActorEntry`);
	}

	const backupSource = backupActorEntry.actor.getMapper()?.getInputData();

	// Copy into a new instance rather than handing the backup object itself to
	// the live mapper.
	const restoredPolyData = vtkPolyData.newInstance();
	restoredPolyData.shallowCopy(backupSource);

	const liveMapper = actorEntry.actor.getMapper();
	liveMapper?.setInputData(restoredPolyData);
}

/**
 * Runs `hidePointsInPointCloud` once for every entry of `bones`, passing the
 * same visibility flags and state ref each time.
 *
 * @param areDigitalTwinsVisible - Forwarded unchanged to each per-bone call.
 * @param areResectionPlanesVisible - Forwarded unchanged to each per-bone call.
 * @param vtkState - Ref to the shared VTK state, forwarded unchanged.
 */
export function hidePointsInPointClouds({
	areDigitalTwinsVisible,
	areResectionPlanesVisible,
	vtkState,
}: {
	areDigitalTwinsVisible: boolean;
	areResectionPlanesVisible: boolean;
	vtkState: VtkStateRef;
}) {
	const sharedOptions = {
		areDigitalTwinsVisible,
		areResectionPlanesVisible,
		vtkState,
	};

	for (const bone of bones) {
		hidePointsInPointCloud({...sharedOptions, bone});
	}
}

/**
 * Rebuilds the visible point set of one bone's point cloud.
 *
 * The live actor is first reset from its backup. Its points are then narrowed:
 *
 * 1. When `areResectionPlanesVisible` is set, by every resection plane actor
 *    that belongs to the selected plane pair and to the plane mapped from the
 *    actor entry's label (`femur` -> `femoral`, `tibia` -> `tibial`); only
 *    points on the positive side of the plane normal are kept.
 * 2. When `areDigitalTwinsVisible` is set, by the digital twin shapes of this
 *    bone, depending on `vtkState.current.digitalTwinMode`:
 *    - `'remaining'`: each shape in turn drops the points it contains, working
 *      on whatever the previous step left.
 *    - `'removed'`: only points contained in at least one shape are kept,
 *      de-duplicated by their coordinates. With no matching shape the cloud
 *      ends up empty.
 *
 * @param areDigitalTwinsVisible - Enables the digital twin filtering step.
 * @param areResectionPlanesVisible - Enables the resection plane filtering step.
 * @param bone - Bone whose point cloud is processed.
 * @param vtkState - Ref to the shared VTK state.
 */
export function hidePointsInPointCloud({
	areDigitalTwinsVisible,
	areResectionPlanesVisible,
	bone,
	vtkState,
}: {
	areDigitalTwinsVisible: boolean;
	areResectionPlanesVisible: boolean;
	bone: Bone;
	vtkState: VtkStateRef;
}) {
	type DigitalTwinEntry = NonNullable<
		VtkStateRef['current']['digitalTwinActors']
	>[number];

	const {actorEntry, backupActorEntry} = getPointCloudActorEntryAndBackup({
		bone,
		vtkState,
	});
	restoreActorFromBackup({actorEntry, backupActorEntry});

	// Reads the points and scalar values currently fed to the live actor.
	const readCurrentCloud = () => {
		const polyData = actorEntry.actor.getMapper()?.getInputData();

		return {
			cloudPoints: polyData.getPoints(),
			cloudScalars: polyData.getPointData().getScalars().getData(),
		};
	};

	// Filters `cloud` against one digital twin shape, forwarding only the
	// dimensions that the shape's type carries.
	const filterCloudByTwin = (
		twin: DigitalTwinEntry,
		cloud: ReturnType<typeof readCurrentCloud>,
		filterPerspective: 'inside' | 'outside',
	) =>
		findPointsAndScalarsRelativeToShape({
			shape: twin.type,
			center: twin.center,
			radius: (twin as ReamDigitalTwinObjectActorEntry).radius,
			startingPoints: cloud.cloudPoints,
			startingScalars: cloud.cloudScalars,
			filterPerspective,
			...(twin.type === 'drill' && {
				height: twin.height,
				direction: twin.directions.y,
			}),
			...(twin.type === 'resect' && {
				depth: twin.depth,
				height: twin.height,
				width: twin.width,
				rotation: twin.rotation,
			}),
		});

	if (areResectionPlanesVisible) {
		const boneLabelToResectionPlaneMap: Record<
			'femur' | 'tibia',
			ResectionPlaneKey
		> = {
			femur: 'femoral',
			tibia: 'tibial',
		};

		const matchingPlanes = vtkState.current.resectionPlaneActors.filter(
			({pair, plane}) =>
				pair === vtkState.current.selectedResectionPlanePair &&
				plane ===
					boneLabelToResectionPlaneMap[actorEntry.label as 'femur' | 'tibia'],
		);

		for (const {origin, normal} of matchingPlanes) {
			// Each plane works on what the previous plane left visible.
			const {cloudPoints, cloudScalars} = readCurrentCloud();
			const {points, scalars} = findPointsAndScalarsRelativeToPlane({
				backupPoints: cloudPoints,
				backupScalars: cloudScalars,
				planeOrigin: origin,
				planeNormal: normal,
			});

			showOnlySelectedPoints({actorEntry, points, scalars});
		}
	}

	if (!areDigitalTwinsVisible) {
		return;
	}

	if (vtkState.current.digitalTwinMode === 'remaining') {
		vtkState.current.digitalTwinActors
			?.filter((twin) => twin.bone === bone)
			.forEach((twin) => {
				// Re-read the cloud so every shape carves from the current result.
				const {points, scalars} = filterCloudByTwin(
					twin,
					readCurrentCloud(),
					'outside',
				);

				showOnlySelectedPoints({actorEntry, points, scalars});
			});

		return;
	}

	if (vtkState.current.digitalTwinMode !== 'removed') {
		return;
	}

	const seenPointKeys = new Set<string>();
	const removedPoints: number[] = [];
	const removedScalars: number[] = [];

	// Every shape is tested against the same starting cloud; results are merged.
	const startingCloud = readCurrentCloud();

	vtkState.current.digitalTwinActors
		?.filter((twin) => twin.bone === bone)
		.forEach((twin) => {
			const {points, scalars} = filterCloudByTwin(
				twin,
				startingCloud,
				'inside',
			);

			for (let offset = 0; offset < points.length; offset += 3) {
				const pointKey = `${points[offset]},${points[offset + 1]},${points[offset + 2]}`;

				// A point inside several shapes is only emitted once.
				if (seenPointKeys.has(pointKey)) {
					continue;
				}

				seenPointKeys.add(pointKey);
				removedPoints.push(
					points[offset],
					points[offset + 1],
					points[offset + 2],
				);
				removedScalars.push(scalars[offset / 3]);
			}
		});

	showOnlySelectedPoints({
		actorEntry,
		points: new Float32Array(removedPoints),
		scalars: new Float32Array(removedScalars),
	});
}

/**
 * Keeps the points that lie strictly on the positive side of a plane, i.e.
 * where `(point - planeOrigin) · planeNormal > 0`, along with their scalars.
 *
 * @param backupPoints - Object exposing `getNumberOfValues()` (three values per
 * point) and `getPoint(index)`.
 * @param backupScalars - Index-addressable scalar values, one per point.
 * @param planeOrigin - A point on the plane.
 * @param planeNormal - The plane normal; it is used as given, not normalised.
 * @returns Views over the kept coordinates (three per point) and scalars.
 */
function findPointsAndScalarsRelativeToPlane({
	backupPoints,
	backupScalars,
	planeOrigin,
	planeNormal,
}: {
	backupPoints: any;
	backupScalars: any;
	planeOrigin: Vector3;
	planeNormal: Vector3;
}): {
	points: Float32Array;
	scalars: Float32Array;
} {
	const [originX, originY, originZ] = planeOrigin;
	const [normalX, normalY, normalZ] = planeNormal;
	const totalPoints = backupPoints.getNumberOfValues() / 3;

	// Allocate for the worst case (everything kept) and trim on return.
	const keptPoints = new Float32Array(totalPoints * 3);
	const keptScalars = new Float32Array(totalPoints);
	let keptCount = 0;

	for (let index = 0; index < totalPoints; index += 1) {
		const [x, y, z] = backupPoints.getPoint(index);
		const signedDistance =
			(x - originX) * normalX +
			(y - originY) * normalY +
			(z - originZ) * normalZ;

		// Points on the plane itself (or with a NaN result) are dropped.
		if (!(signedDistance > 0)) {
			continue;
		}

		const offset = keptCount * 3;
		keptPoints[offset] = x;
		keptPoints[offset + 1] = y;
		keptPoints[offset + 2] = z;
		keptScalars[keptCount] = backupScalars[index];
		keptCount += 1;
	}

	return {
		points: keptPoints.subarray(0, keptCount * 3),
		scalars: keptScalars.subarray(0, keptCount),
	};
}

/**
 * Transforms `point` by a matrix built as: translate by `center`, rotate about
 * Y, X and Z by the negated `rotation` angles, then translate by `-center`.
 *
 * The rotation components are multiplied by `Math.PI / 180` before being
 * passed to gl-matrix, i.e. they are treated as degrees.
 *
 * @param point - Point to transform.
 * @param center - Pivot the rotations are applied around.
 * @param rotation - Per-axis angles; every axis is negated before use.
 * @returns The transformed point as a new three-element array.
 */
function transformPointRelativeToRotation(
	point: Vector3,
	center: Vector3,
	rotation: Rotation,
): Vector3 {
	const degreesToRadians = Math.PI / 180;
	const pivot = center as vec3;

	// Start from identity and move the pivot to the origin of the rotation.
	const transform = mat4.create();
	mat4.translate(transform, transform, pivot);

	// All three axes use the negated angle; the call order is Y, X, Z.
	mat4.rotateY(transform, transform, -rotation.y * degreesToRadians);
	mat4.rotateX(transform, transform, -rotation.x * degreesToRadians);
	mat4.rotateZ(transform, transform, -rotation.z * degreesToRadians);

	// Undo the initial translation.
	const negatedPivot = vec3.negate(vec3.create(), pivot);
	mat4.translate(transform, transform, negatedPivot);

	const result = vec3.transformMat4(vec3.create(), point as vec3, transform);

	return [result[0], result[1], result[2]] as Vector3;
}

/**
 * Selects the points (and matching scalars) that lie inside or outside a
 * digital twin shape.
 *
 * Supported shapes:
 * - `'drill'`: a cylinder centred on `center`, with axis `direction`, total
 *   length `height` and the given `radius`.
 * - `'ream'`: a sphere of `radius` around `center`.
 * - `'resect'`: a box centred on `center` measuring `width` (X) x `depth` (Y) x
 *   `height` (Z); each point is passed through
 *   `transformPointRelativeToRotation` before the axis-aligned test.
 *
 * Boundaries are inclusive. If the shape is not recognised, or a drill/resect
 * is missing one of its dimensions (or has a zero one), the shape is treated
 * as containing no points: `'inside'` yields nothing and `'outside'` yields
 * every point. Previously `'outside'` also yielded nothing in that situation,
 * which emptied the whole cloud.
 *
 * @param center - Centre of the shape.
 * @param depth - Box extent along Y (resect only).
 * @param direction - Cylinder axis; normalised internally (drill only).
 * @param filterPerspective - `'inside'` keeps contained points, `'outside'`
 * keeps the rest.
 * @param height - Cylinder length (drill) or box extent along Z (resect).
 * @param startingPoints - Object exposing `getNumberOfValues()` (three values
 * per point) and `getPoint(index)`.
 * @param startingScalars - Index-addressable scalar values, one per point.
 * @param radius - Cylinder or sphere radius.
 * @param rotation - Box rotation (resect only).
 * @param shape - Which shape to test against.
 * @param width - Box extent along X (resect only).
 * @returns Views over the kept coordinates (three per point) and scalars.
 */
function findPointsAndScalarsRelativeToShape({
	center,
	depth,
	direction,
	filterPerspective = 'inside',
	height,
	startingPoints,
	startingScalars,
	radius,
	rotation,
	shape,
	width,
}: {
	center: Vector3;
	depth?: number;
	direction?: Coordinate;
	filterPerspective: 'inside' | 'outside';
	height?: number;
	startingPoints: any;
	startingScalars: any;
	radius: number;
	rotation?: Rotation;
	shape: DigitalTwinPolygonType;
	width?: number;
}): {
	points: Float32Array;
	scalars: Float32Array;
} {
	const numberOfPoints = startingPoints.getNumberOfValues() / 3;

	// Allocate for the worst case (everything kept) and trim on return.
	const filteredPointsArray = new Float32Array(numberOfPoints * 3);
	const filteredScalarsArray = new Float32Array(numberOfPoints);
	let count = 0;

	// Fallback for shapes that cannot be evaluated: nothing is inside them.
	let isInsideShape: (point: Vector3) => boolean = () => false;

	if (shape === 'drill' && height && direction) {
		const normalizedAxis = vec3.normalize(vec3.create(), direction as vec3);
		const halfHeight = height / 2;

		isInsideShape = (point) => {
			const vectorToPoint = vec3.fromValues(
				point[0] - center[0],
				point[1] - center[1],
				point[2] - center[2],
			);
			// Distance along the axis, measured from the cylinder's centre.
			const projectionDistance = Math.abs(
				vec3.dot(vectorToPoint, normalizedAxis),
			);
			// |v x axis| is the distance from the axis line (axis has unit length).
			const perpendicularDistance = vec3.length(
				vec3.cross(vec3.create(), vectorToPoint, normalizedAxis),
			);

			return (
				projectionDistance <= halfHeight && perpendicularDistance <= radius
			);
		};
	} else if (shape === 'ream') {
		const radiusSquared = radius * radius;

		isInsideShape = (point) => {
			const distanceSquared =
				(point[0] - center[0]) ** 2 +
				(point[1] - center[1]) ** 2 +
				(point[2] - center[2]) ** 2;

			return distanceSquared <= radiusSquared;
		};
	} else if (shape === 'resect' && width && height && depth && rotation) {
		const halfWidth = width / 2; // X-axis
		const halfHeight = height / 2; // Z-axis
		const halfDepth = depth / 2; // Y-axis
		const boxRotation = rotation;

		isInsideShape = (point) => {
			const transformedPoint = transformPointRelativeToRotation(
				point,
				center,
				boxRotation,
			);

			return (
				Math.abs(transformedPoint[0] - center[0]) <= halfWidth &&
				Math.abs(transformedPoint[1] - center[1]) <= halfDepth &&
				Math.abs(transformedPoint[2] - center[2]) <= halfHeight
			);
		};
	}

	const keepInside = filterPerspective === 'inside';
	const keepOutside = filterPerspective === 'outside';

	for (let i = 0; i < numberOfPoints; i++) {
		const point = startingPoints.getPoint(i) as Vector3;
		const isInside = isInsideShape(point);

		if ((keepInside && isInside) || (keepOutside && !isInside)) {
			// The original (untransformed) coordinates are what gets stored.
			filteredPointsArray.set(point, count * 3);
			filteredScalarsArray[count] = startingScalars[i];
			count++;
		}
	}

	return {
		points: filteredPointsArray.subarray(0, count * 3),
		scalars: filteredScalarsArray.subarray(0, count),
	};
}

/**
 * Overwrites the points and scalars of the poly data currently fed to the
 * actor's mapper.
 *
 * @param actorEntry - Entry whose actor's mapper input is modified in place.
 * @param points - Flat coordinate array, three values per point.
 * @param scalars - One value per point, stored as a single-component array
 * named `Scalars`.
 */
function showOnlySelectedPoints({
	actorEntry,
	points,
	scalars,
}: {
	actorEntry: ActorEntry;
	points: Float32Array;
	scalars: Float32Array;
}): void {
	const targetPolyData = actorEntry.actor.getMapper()?.getInputData();

	// Swap in a new points object instead of editing the existing one.
	const replacementPoints = vtkPoints.newInstance();
	replacementPoints.setData(points, 3);
	targetPolyData.setPoints(replacementPoints);

	const pointData = targetPolyData.getPointData();
	pointData.setScalars(
		vtkDataArray.newInstance({
			numberOfComponents: 1,
			values: scalars,
			name: 'Scalars',
		}),
	);
}

/**
 * Builds one lookup table for all point cloud actors from their combined
 * range and the given colour scheme, assigns it to every actor's mapper,
 * applies the threshold, and re-renders the 3D viewports.
 *
 * @param colorScheme - Scheme passed to `makeLut`; when its
 * `mapColorsToThreshold` flag is set, the mapping range is also set to the
 * threshold bounds.
 * @param threshold - Only `lower` and `upper` are read here.
 * @param vtkState - State holding `pointCloudActors`.
 * @returns The lookup table that was assigned and the range it was built from.
 */
export function setPointCloudColorScheme({
	colorScheme,
	threshold,
	vtkState,
}: {
	colorScheme: PointCloudColorScheme;
	threshold: {
		lower: number;
		lut?: any;
		max: number;
		min: number;
		upper: number;
	};
	vtkState: VtkState;
}): {
	globalLut: ReturnType<typeof makeLut>;
	globalRange: ReturnType<typeof calculateGlobalRange>;
} {
	const globalRange = calculateGlobalRange(
		vtkState.pointCloudActors.map(({actor}) => actor),
	);
	const globalLut = makeLut(globalRange, colorScheme);

	for (const {actor} of vtkState.pointCloudActors) {
		// NOTE(review): the range is set on the mapper's current lookup table,
		// which is replaced on the next line — confirm this is intentional.
		actor
			.getMapper()
			?.getLookupTable()
			.setRange([globalRange.min, globalRange.max]);
		actor.getMapper()?.setLookupTable(globalLut);
	}

	const {lower, upper} = threshold;
	globalLut.setThresholdRange(lower, upper);

	if (colorScheme.mapColorsToThreshold) {
		globalLut.setMappingRange(lower, upper);
	}

	cornerstone.renderViewports('3d');

	return {globalLut, globalRange};
}

/**
 * Sets the opacity of every point cloud actor and re-renders the 3D viewports.
 *
 * @param opacity - Value that is divided by 100 before being handed to the
 * actor property.
 * @param vtkState - State holding `pointCloudActors`.
 */
export function setPointCloudOpacity({
	opacity,
	vtkState,
}: {
	opacity: number;
	vtkState: VtkState;
}) {
	const scaledOpacity = opacity / 100;

	vtkState.pointCloudActors.forEach(({actor}) => {
		actor.getProperty().setOpacity(scaledOpacity);
	});

	cornerstone.renderViewports('3d');
}

/**
 * Applies a threshold range to a lookup table and re-renders the 3D viewports.
 *
 * @param colorScheme - When its `mapColorsToThreshold` flag is set, the
 * table's mapping range is set to the same bounds as the threshold.
 * @param lower - Lower threshold bound.
 * @param lut - Lookup table exposing `setThresholdRange` and `setMappingRange`.
 * @param upper - Upper threshold bound.
 */
export function setPointCloudThreshold({
	colorScheme,
	lower,
	lut,
	upper,
}: {
	colorScheme: PointCloudColorScheme;
	lower: number;
	lut: any;
	upper: number;
}) {
	const {mapColorsToThreshold} = colorScheme;

	lut.setThresholdRange(lower, upper);

	if (mapColorsToThreshold) {
		lut.setMappingRange(lower, upper);
	}

	cornerstone.renderViewports('3d');
}

/**
 * Exported maximum-opacity constant for point clouds; not referenced elsewhere
 * in this file.
 * NOTE(review): `setPointCloudOpacity` divides its input by 100, so the scale
 * this value is expressed in should be confirmed against its consumers.
 */
export const pointCloudsMaxOpacity = 1;

/**
 * Exported point-size constant for point clouds; not referenced elsewhere in
 * this file. NOTE(review): units are not shown here — presumably the value
 * handed to the actor property's point size; confirm at the call site.
 */
export const pointCloudsPointSize = 1;
