/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.image.processing.face.alignment;

import Jama.Matrix;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.openimaj.image.FImage;
import org.openimaj.image.Image;
import org.openimaj.image.ImageUtilities;
import org.openimaj.image.processing.face.alignment.AffineAligner;
import org.openimaj.image.processing.face.alignment.FaceAligner;
import org.openimaj.image.processing.face.detection.keypoints.FKEFaceDetector;
import org.openimaj.image.processing.face.detection.keypoints.FacialKeypoint;
import org.openimaj.image.processing.face.detection.keypoints.KEDetectedFace;
import org.openimaj.image.processing.transform.PiecewiseMeshWarp;
import org.openimaj.image.processor.ImageProcessor;
import org.openimaj.math.geometry.point.Point2d;
import org.openimaj.math.geometry.point.Point2dImpl;
import org.openimaj.math.geometry.shape.Polygon;
import org.openimaj.math.geometry.shape.Shape;
import org.openimaj.math.geometry.transforms.TransformUtilities;
import org.openimaj.util.pair.Pair;

/**
 * A {@link FaceAligner} that maps a {@link KEDetectedFace} onto a fixed 80x80 canonical frame
 * using a piecewise mesh warp.
 *
 * <p>The mesh is described by a {@code String[][]}: each inner array lists the vertices of one
 * polygon. Every vertex name is either the name of a {@link FacialKeypoint.FacialKeypointType}
 * constant or one of {@code "P0"}, {@code "P1"}, {@code "P2"}, {@code "P3"}, which denote the
 * frame corners (0,0), (80,0), (80,80) and (0,80). Each polygon is built twice, once from the
 * detected keypoints and once from the canonical keypoints. The resulting pairs are passed to
 * {@link PiecewiseMeshWarp}.
 */
public class MeshWarpAligner
implements FaceAligner<KEDetectedFace> {
    /** Default mesh: ten triangles spanning the eye, nose and mouth keypoints. */
    private static final String[][] DEFAULT_MESH_DEFINITION = new String[][]{
        {"EYE_LEFT_RIGHT", "EYE_RIGHT_LEFT", "NOSE_MIDDLE"},
        {"EYE_LEFT_LEFT", "EYE_LEFT_RIGHT", "NOSE_LEFT"},
        {"EYE_RIGHT_RIGHT", "EYE_RIGHT_LEFT", "NOSE_RIGHT"},
        {"EYE_LEFT_RIGHT", "NOSE_LEFT", "NOSE_MIDDLE"},
        {"EYE_RIGHT_LEFT", "NOSE_RIGHT", "NOSE_MIDDLE"},
        {"MOUTH_LEFT", "MOUTH_RIGHT", "NOSE_MIDDLE"},
        {"MOUTH_LEFT", "NOSE_LEFT", "NOSE_MIDDLE"},
        {"MOUTH_RIGHT", "NOSE_RIGHT", "NOSE_MIDDLE"},
        {"MOUTH_LEFT", "NOSE_LEFT", "EYE_LEFT_LEFT"},
        {"MOUTH_RIGHT", "NOSE_RIGHT", "EYE_RIGHT_RIGHT"}
    };

    // Corners of the output frame; P2 also defines the frame's width and height.
    private static final Point2d P0 = new Point2dImpl(0.0f, 0.0f);
    private static final Point2d P1 = new Point2dImpl(80.0f, 0.0f);
    private static final Point2d P2 = new Point2dImpl(80.0f, 80.0f);
    private static final Point2d P3 = new Point2dImpl(0.0f, 80.0f);

    /** Canonical keypoint positions, derived once from {@link AffineAligner#Pmu} and shared. */
    private static final FacialKeypoint[] CANONICAL = MeshWarpAligner.loadCanonicalPoints();

    String[][] meshDefinition = DEFAULT_MESH_DEFINITION;
    FImage mask;

    /** Creates an aligner using the default ten-triangle mesh. */
    public MeshWarpAligner() {
        this(DEFAULT_MESH_DEFINITION);
    }

    /**
     * Creates an aligner for a custom mesh and precomputes its mask.
     *
     * @param meshDefinition polygon vertex names (keypoint type names or "P0".."P3"); the array
     *     is stored by reference, not copied
     * @throws NullPointerException if {@code meshDefinition} is null
     * @throws IllegalArgumentException if a vertex name is not a known keypoint type or corner
     */
    public MeshWarpAligner(String[][] meshDefinition) {
        if (meshDefinition == null) {
            throw new NullPointerException("meshDefinition");
        }
        this.meshDefinition = meshDefinition;
        // NOTE(review): createMesh is protected (overridable) and runs here before any subclass
        // constructor body has executed; overrides must not rely on subclass state.
        final List<Pair<Shape>> mesh = this.createMesh(CANONICAL);
        // The mask starts as an all-ones frame and is pushed through the canonical mesh.
        // Presumably it marks the area covered by the mesh; TODO confirm against PiecewiseMeshWarp.
        final FImage ones = new FImage((int) P2.getX(), (int) P2.getY());
        ones.fill(1.0f);
        this.mask = ones.processInplace(new PiecewiseMeshWarp<Float, FImage>(mesh));
    }

    /**
     * Builds the canonical keypoints as {@code 2 * Pmu - 40} for each column of
     * {@link AffineAligner#Pmu}. Row 0 supplies x and row 1 supplies y.
     */
    private static FacialKeypoint[] loadCanonicalPoints() {
        final FacialKeypoint[] points = new FacialKeypoint[AffineAligner.Pmu[0].length];
        for (int i = 0; i < points.length; ++i) {
            points[i] = new FacialKeypoint(FacialKeypoint.FacialKeypointType.valueOf(i));
            points[i].position = new Point2dImpl(2.0f * AffineAligner.Pmu[0][i] - 40.0f, 2.0f * AffineAligner.Pmu[1][i] - 40.0f);
        }
        return points;
    }

    /**
     * Maps the detected keypoints into the output frame.
     *
     * @param keys detected keypoints; must contain every type indexed by {@code AffineAligner.Pmu}
     * @param tf0 transform applied to each keypoint position
     * @return new keypoints, one per type index, with transformed positions
     */
    protected FacialKeypoint[] getActualPoints(FacialKeypoint[] keys, Matrix tf0) {
        final FacialKeypoint[] points = new FacialKeypoint[AffineAligner.Pmu[0].length];
        for (int i = 0; i < points.length; ++i) {
            final FacialKeypoint.FacialKeypointType type = FacialKeypoint.FacialKeypointType.valueOf(i);
            points[i] = new FacialKeypoint(type);
            points[i].position = new Point2dImpl((Point2d) FacialKeypoint.getKeypoint(keys, type).position.transform(tf0));
        }
        return points;
    }

    /**
     * Instantiates the mesh definition against a set of keypoints.
     *
     * @param det keypoints used for the first polygon of each pair
     * @return one pair per mesh polygon: (polygon from {@code det}, polygon from the canonical
     *     points)
     */
    protected List<Pair<Shape>> createMesh(FacialKeypoint[] det) {
        final List<Pair<Shape>> shapes = new ArrayList<Pair<Shape>>(this.meshDefinition.length);
        for (final String[] vertDefs : this.meshDefinition) {
            final Polygon detected = new Polygon();
            final Polygon target = new Polygon();
            for (final String v : vertDefs) {
                detected.getVertices().add(this.lookupVertex(v, det));
                target.getVertices().add(this.lookupVertex(v, CANONICAL));
            }
            shapes.add(new Pair<Shape>(detected, target));
        }
        return shapes;
    }

    /**
     * Resolves a vertex name to a point.
     *
     * @param v "P0".."P3" for a frame corner, otherwise a {@code FacialKeypointType} constant name
     * @param pts keypoints to search for a named keypoint
     * @return the corner constant or the matching keypoint's position object (not a copy)
     * @throws IllegalArgumentException if {@code v} names no keypoint type, or {@code pts} has no
     *     keypoint of that type
     */
    private Point2d lookupVertex(String v, FacialKeypoint[] pts) {
        if (v.equals("P0")) {
            return P0;
        }
        if (v.equals("P1")) {
            return P1;
        }
        if (v.equals("P2")) {
            return P2;
        }
        if (v.equals("P3")) {
            return P3;
        }
        final FacialKeypoint.FacialKeypointType type = FacialKeypoint.FacialKeypointType.valueOf(v);
        final FacialKeypoint keypoint = FacialKeypoint.getKeypoint(pts, type);
        if (keypoint == null) {
            // Previously surfaced as an anonymous NullPointerException on `.position`.
            throw new IllegalArgumentException("No keypoint of type " + type + " for mesh vertex \"" + v + "\"");
        }
        return keypoint.position;
    }

    /**
     * Aligns a face.
     *
     * <p>The face patch is rescaled to the output frame, then warped so that its keypoints move
     * onto the canonical keypoints.
     *
     * @param descriptor detected face with keypoints and a face patch
     * @return the warped patch, sized by the frame corner {@code P2}
     */
    @Override
    public FImage align(KEDetectedFace descriptor) {
        final FImage facePatch = descriptor.getFacePatch();
        final float scalingX = P2.getX() / (float) facePatch.width;
        final float scalingY = P2.getY() / (float) facePatch.height;
        final Matrix tf0 = TransformUtilities.scaleMatrix((double) scalingX, (double) scalingY);
        final Matrix tf = tf0.inverse();
        final FImage resized = FKEFaceDetector.pyramidResize(facePatch, tf);
        // Patch size derives from the frame corner so it always matches the mask and mesh frame.
        final FImage smallpatch = FKEFaceDetector.extractPatch(resized, tf, (int) P2.getX(), 0);
        return this.getWarpedImage(descriptor.getKeypoints(), smallpatch, tf0);
    }

    /**
     * Warps a patch so that the detected keypoints land on the canonical keypoints.
     *
     * @param kpts detected keypoints in the original patch's coordinates
     * @param patch patch already resampled into the output frame
     * @param tf0 transform taking {@code kpts} into the output frame
     * @return a new warped image; {@code patch} itself is not modified in place
     */
    protected FImage getWarpedImage(FacialKeypoint[] kpts, FImage patch, Matrix tf0) {
        final FacialKeypoint[] det = this.getActualPoints(kpts, tf0);
        final List<Pair<Shape>> mesh = this.createMesh(det);
        return patch.process(new PiecewiseMeshWarp<Float, FImage>(mesh));
    }

    /** Returns the mask computed at construction or loaded by {@link #readBinary}. */
    @Override
    public FImage getMask() {
        return this.mask;
    }

    /**
     * Restores the mesh definition and mask written by {@link #writeBinary}.
     *
     * <p>State is replaced only after the whole record has been read, so a failed read leaves
     * this instance unchanged.
     *
     * @param in the source stream
     * @throws IOException on read failure or a corrupt (negative) element count
     */
    public void readBinary(DataInput in) throws IOException {
        final int polygonCount = readCount(in);
        final String[][] definition = new String[polygonCount][];
        for (int i = 0; i < polygonCount; ++i) {
            final int vertexCount = readCount(in);
            definition[i] = new String[vertexCount];
            for (int j = 0; j < vertexCount; ++j) {
                definition[i][j] = in.readUTF();
            }
        }
        final FImage readMask = ImageUtilities.readF(in);
        this.meshDefinition = definition;
        this.mask = readMask;
    }

    /** Reads an element count, rejecting negative values as corrupt input. */
    private static int readCount(DataInput in) throws IOException {
        final int count = in.readInt();
        if (count < 0) {
            throw new IOException("Corrupt MeshWarpAligner data: negative count " + count);
        }
        return count;
    }

    /** Returns the fully-qualified class name as UTF-8 bytes, independent of platform charset. */
    public byte[] binaryHeader() {
        return this.getClass().getName().getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Writes the mesh definition followed by the mask.
     *
     * <p>The definition is written as a polygon count, then for each polygon a vertex count and
     * its vertex names. The mask is written as a PNG.
     *
     * @param out the destination stream
     * @throws IOException on write failure
     */
    public void writeBinary(DataOutput out) throws IOException {
        out.writeInt(this.meshDefinition.length);
        for (final String[] def : this.meshDefinition) {
            out.writeInt(def.length);
            for (final String s : def) {
                out.writeUTF(s);
            }
        }
        ImageUtilities.write(this.mask, "png", out);
    }
}

