IT WORKS!?!?

I have spent teh last week attempting to implement proper BFGS optimisation to replace my gradient descent. I have been encountering mysterious problem after mysterious problem, from random noise in my deterministic data to fake derivatives being better than real derivatives. But now it works. And I need to push this _right now_ before it breaks.
This commit is contained in:
Justin Kunimune 2018-01-29 21:05:56 -10:00
parent 76f0ce6f29
commit b51f632cf5
6 changed files with 541 additions and 97 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 17 KiB

View File

@ -1,11 +1,12 @@
We got the best AuthaPower projections using:
t0=0.5431203354626852; (0.42067543582420297, 0.1080699456387831)
t0=0.5520863481348803; (0.41151588635691044, 0.10836740216767649)
t0=0.5649999999999998; (0.39832544719243623, 0.1117772180300048)
t0=0.5799999999999997; (0.3830071497982258, 0.11998391934038882)
t0=0.6070810562218636; (0.3553615292439163, 0.14328643594135124)
t0=0.9249999999999975; (0.041512185218508045, 0.5598094651536242)
t0=0.9399999999999974; (0.032950786280409324, 0.5768242162838197)
t0=0.9475257414659873; (0.030720805053792846, 0.5851535892397473)
t0=0.9527575327203568; (0.03025041972521741, 0.5911164500406887)
We got the best Tobler Hyperelliptical projections using:
t0=30.506558786136164; t1=0.3975383525206906; t2=3.997720064276123; (2.5245792765063137E-4, 0.49042998348577777)
We got the best Winkel Tripel projections using:
t0=17.957933140628256; (0.25932019508548493, 0.36444804618226867)
We got the best TetraPower projections using:
t0=0.7251510971705017; t1=0.8799415807400688; t2=0.8572850133622626; (0.422750395775632, 0.13826331216471252)
We got the best AuthaPower projections using:
t0=0.5500000000000003; (0.41312749977668417, 0.08828966383834298)

View File

@ -26,6 +26,7 @@ package apps;
import java.io.File;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.function.Function;
import javax.imageio.ImageIO;
@ -38,14 +39,15 @@ import javafx.scene.chart.NumberAxis;
import javafx.scene.chart.XYChart.Data;
import javafx.scene.chart.XYChart.Series;
import javafx.stage.Stage;
import maps.Cylindrical;
import maps.Lenticular;
import maps.Misc;
import maps.Projection;
import maps.Arbitrary;
import maps.Cylindrical;
import maps.Misc;
import maps.Polyhedral;
import maps.Projection;
import maps.Tobler;
import maps.WinkelTripel;
import utils.linalg.Matrix;
import utils.linalg.Vector;
/**
* An application to compare and optimize map projections
@ -54,16 +56,22 @@ import maps.WinkelTripel;
*/
public class MapOptimizer extends Application {
private static final Projection[] EXISTING_PROJECTIONS = { Cylindrical.HOBO_DYER,
Arbitrary.ROBINSON, Lenticular.VAN_DER_GRINTEN, Misc.PEIRCE_QUINCUNCIAL };
private static final Projection[] EXISTING_PROJECTIONS = { Cylindrical.BEHRMANN,
Arbitrary.ROBINSON, Cylindrical.GALL_STEREOGRAPHIC, Misc.PEIRCE_QUINCUNCIAL };
private static final Projection[] PROJECTIONS_TO_OPTIMIZE = { Tobler.TOBLER,
WinkelTripel.WINKEL_TRIPEL, Polyhedral.TETRAPOWER, Polyhedral.AUTHAPOWER };
private static final double[] WEIGHTS = { 0., .125, .25, .375, .5, .625, .75, .875, 1. };
private static final int NUM_BRUTE_FORCE = 100;
private static final int NUM_DESCENT = 30;
private static final double DEL_X = 0.01;
// private static final double[] WEIGHTS = { 0., .125, .25, .375, .5, .625, .75, .875, 1. };
private static final double[] WEIGHTS = {0};
private static final int NUM_SAMPLE = 1;
private static final int NUM_BRUTE_FORCE = 30;
private static final int NUM_BFGS_ITERATE = 6;
private static final double GOLDSTEIN_C = 0.5;
private static final double BACKTRACK_TAU = 0.5;
private static final double BACKTRACK_ALF0 = 4;
private static final double DEL_X = 0.05;
private LineChart<Number, Number> chart;
private static final double[][][] GLOBE = Projection.hemisphere(0.01);
public static final void main(String[] args) {
@ -79,12 +87,12 @@ public class MapOptimizer extends Application {
new NumberAxis("Shape distortion", 0, .6, 0.1));
chart.setCreateSymbols(true);
chart.setAxisSortingPolicy(SortingPolicy.NONE);
double[][][] globe = Projection.globe(0.01);
PrintStream log = new PrintStream(new File("output/parameters.txt"));
chart.getData().add(analyzeAll(globe, EXISTING_PROJECTIONS));
chart.getData().add(analyzeAll(EXISTING_PROJECTIONS));
for (Projection p: PROJECTIONS_TO_OPTIMIZE)
chart.getData().add(optimizeFamily(p, globe, log));
chart.getData().add(optimiseFamily(p, log));
System.out.println("Total time elapsed: " + (System.currentTimeMillis() - startTime) / 60000. + "m");
@ -99,49 +107,107 @@ public class MapOptimizer extends Application {
}
private static Series<Number, Number> analyzeAll(double[][][] points, Projection... projs) { //analyze and plot the specified preexisting map projections.
private static Series<Number, Number> analyzeAll(Projection... projs) { //analyze and plot the specified preexisting map projections.
System.out.println("Analyzing " + Arrays.toString(projs));
Series<Number, Number> output = new Series<Number, Number>(); //These projections must not be parametrized
output.setName("Basic Projections");
for (Projection proj : projs)
if (!proj.isParametrized())
output.getData().add(plotDistortion(points, proj, new double[0]));
output.getData().add(plotDistortion(proj, new double[0]));
return output;
}
private static Series<Number, Number> optimizeFamily(
Projection proj, double[][][] points, PrintStream log) { //optimize and plot some maps of a given family
private static Data<Number, Number> plotDistortion(Projection proj, double[] params) {
double[] distortion = proj.avgDistortion(GLOBE, proj.getDefaultParameters());
return new Data<Number, Number>(distortion[0], distortion[1]);
}
private static final double weighDistortion(Projection proj, double[] params, double weight) {
double areaDist = 0, anglDist = 0;
for (int i = 0; i < NUM_SAMPLE; i ++) {
double[] mcParams = new double[params.length];
for (int j = 0; j < mcParams.length; j ++)
mcParams[j] = params[j];// + (Math.random()-.5)*1e-3; //is it weird that I'm introducing stochasticity into this?
double[] distortions = proj.avgDistortion(GLOBE, mcParams);
areaDist += distortions[0];
anglDist += distortions[1];
}
return areaDist/NUM_SAMPLE*weight + anglDist/NUM_SAMPLE*(1-weight);
}
private static Series<Number, Number> optimiseFamily(
Projection proj, PrintStream log) { //optimize and plot some maps of a given family
System.out.println("Optimizing " + proj.getName());
final double[][] currentBest = new double[WEIGHTS.length][3 + proj.getNumParameters()]; //the 0-3 cols are the min distortions for each weight, the
for (int k = 0; k < WEIGHTS.length; k++) //other cols are the values of k and n that caused that
currentBest[k][0] = Integer.MAX_VALUE;
final double[][] bounds = proj.getParameterValues();
final double[] params = new double[proj.getNumParameters()];
final double[][] best = new double[WEIGHTS.length][proj.getNumParameters()];
for (int k = 0; k < WEIGHTS.length; k ++) {
final double weighFactor = WEIGHTS[k];
double[] currentBest = bruteForceMinimise(
(params) -> weighDistortion(proj, params, weighFactor),
proj.getParameterValues());
best[k] = bfgsMinimise(
(params) -> weighDistortion(proj, params, weighFactor),
currentBest);
}
final Series<Number, Number> output = new Series<Number, Number>();
output.setName(proj.getName());
log.println("We got the best " + proj.getName() + " projections using:"); //now log it
for (double[] bestForWeight : best) { //for each weight
log.print("\t");
for (int i = 0; i < proj.getNumParameters(); i++)
log.print("t" + i + "=" + bestForWeight[i] + "; "); //print the parameters used
double[] distortion = proj.avgDistortion(GLOBE, bestForWeight);
log.println("\t(" + distortion[0] + ", " + distortion[1] + ")"); //print the resulting distortion
output.getData().add(new Data<Number, Number>(distortion[0], distortion[1])); //plot it
}
log.println();
return output;
}
/**
* Returns the parameters that minimise the function, based on a simple brute-force
* parameter sweep.
* @param func The function to minimise
* @param bounds Parameter limits for each argument
* @param numTries The approximate number of samples to take
* @return An array containing the best input to func that it found
*/
private static double[] bruteForceMinimise(Function<double[], Double> func, double[][] bounds) {
System.out.println("BF = [");
final double[] params = new double[bounds.length];
for (int i = 0; i < params.length; i++)
params[i] = bounds[i][0]; // initialize params
double bestValue = Double.POSITIVE_INFINITY;
double[] bestParams = new double[params.length];
bruteForceLoop:
while (true) { // start with brute force
double[] distortions = proj.avgDistortion(points, params);
System.out.println(Arrays.toString(params) + ": " + Arrays.toString(distortions));
for (int k = 0; k < WEIGHTS.length; k++) {
final double avgDist = weighDistortion(WEIGHTS[k], distortions);
if (avgDist < currentBest[k][0]) {
currentBest[k][0] = avgDist;
currentBest[k][1] = distortions[0];
currentBest[k][2] = distortions[1];
System.arraycopy(params, 0, currentBest[k], 3, params.length);
}
while (true) { // run until you've exhausted the parameter space
double avgDist = func.apply(params);
if (avgDist < bestValue) {
bestValue = avgDist;
bestParams = params.clone();
}
for (int i = 0; i < params.length; i ++)
System.out.print(params[i]+", ");
System.out.println(avgDist+";");
int i;
for (i = 0; i <= params.length; i++) { // iterate the parameters
if (i == params.length)
break bruteForceLoop; // if you made it through all the parameters without breaking, you're done!
if (i == params.length) {
System.out.println("];");
return bestParams; // if you made it through all the parameters without breaking, you're done!
}
final double step = (bounds[i][1] - bounds[i][0]) /
Math.floor(Math.pow(NUM_BRUTE_FORCE, 1./params.length));
@ -153,60 +219,119 @@ public class MapOptimizer extends Application {
}
}
}
}
private static double[] bfgsMinimise(Function<double[], Double> arrFunction, double[] x0) { //The Broyden-Fletcher-Goldfarb-Shanno algorithm
System.out.println("BFGS = [");
final int n = x0.length;
final Matrix I = Matrix.identity(n);
final Function<Vector, Double> func = (vec) -> arrFunction.apply(vec.asArray());
final double h = 1e-7;
for (int k = 0; k < WEIGHTS.length; k++) { // now do gradient descent
System.arraycopy(currentBest[k], 3, params, 0, params.length);
System.out.println("Starting gradient descent with weight " + WEIGHTS[k] + " and initial parameters "
+ Arrays.toString(params));
double fr0 = currentBest[k][0];
double[] frd = new double[params.length];
for (int i = 0; i < NUM_DESCENT; i++) {
System.out.println(Arrays.toString(params) + " -> " + fr0);
for (int j = 0; j < params.length; j++) {
params[j] += h;
frd[j] = weighDistortion(WEIGHTS[k], proj.avgDistortion(points, params)); // calculate the distortion nearby
params[j] -= h;
}
for (int j = 0; j < params.length; j++)
params[j] -= (frd[j] - fr0)/h * Math.pow(bounds[j][1]-bounds[j][0],2) * DEL_X; // use that to approximate the gradient and go in the other direction
final double[] distsHere = proj.avgDistortion(points, params);
fr0 = weighDistortion(WEIGHTS[k], distsHere); // calculate the distortion here
if (fr0 <= currentBest[k][0]) { // make sure we are still descending
currentBest[k][0] = fr0; // and save the current datum
System.arraycopy(distsHere, 0, currentBest[k], 1, 2);
System.arraycopy(params, 0, currentBest[k], 3, params.length);
}
else
break;
Vector xk = new Vector(x0); //initial variable values
double fxk = func.apply(xk);
// System.out.println(hessian(func, xk, fxk));
// System.out.println(grad(func, xk, fxk));
// Matrix H = hessian(func, xk, fxk);
Matrix Binv = hessian(func, xk, fxk).inverse();
// Matrix Binv = Matrix.identity(n); //initial approximate Hessian inverse
Vector gradFxk = grad(func, xk, fxk); //function at current location
// System.out.println("anetauson!");
// System.out.print("A"+NUM_SAMPLE+" = [");
// for (double d = 0; d <= 1e-1; d += 5e-4) {
// xk = new Vector(17.8, -17.8);
// Vector dx = Vector.unit(0, n).times(d);
// double f = func.apply(xk.minus(new Vector(-1,1)).plus(dx));
// System.out.println(xk.minus(new Vector(-1,1)).plus(dx).getElement(0)+", "+f+";");
// }
// System.out.println("];");
for (int k = 0; k < NUM_BFGS_ITERATE; k ++) { //(I'm not sure how to test for convergence here, so I'm just running a set number of iterations)
Vector pk = Vector.fromMatrix(Binv.times(gradFxk)); //apply Newton's method for initial step direction
pk = pk.times(-Math.signum(pk.dot(gradFxk))); //but make sure it points downhill
double alfk = BACKTRACK_ALF0; //perform a backtracking line search to find the best alpha
double fxkp1 = func.apply(xk.plus(pk.times(alfk)));
while ((!Double.isFinite(fxkp1) || fxkp1 > fxk + alfk*pk.dot(gradFxk)*GOLDSTEIN_C)) {
if (alfk <= 1e-5)
return xk.asArray(); //a simple way to check for convergence: if xk gets ridiculously small, we're done here.
// for (double d: xk.plus(pk.times(alfk)).asArray())
// System.out.print(d+", ");
// System.out.println(fxkp1+";");
alfk *= BACKTRACK_TAU;
fxkp1 = func.apply(xk.plus(pk.times(alfk)));
}
Vector sk = pk.times(alfk); //iterate
Vector xkp1 = xk.plus(sk);
Vector gradFxkp1 = grad(func, xkp1, fxkp1); //compute new gradient
Vector yk = gradFxkp1.minus(gradFxk); //and gradient change
Matrix a = I.minus(sk.times(yk.T()).times(1/yk.dot(sk)));
Matrix b = sk.times(sk.T()).times(1/yk.dot(sk));
// Binv = hessian(func, xkp1, fxkp1).inverse();
Binv = a.times(Binv).times(a.T()).plus(b); //update Binv
xk = xkp1;
fxk = fxkp1;
gradFxk = gradFxkp1; //and save the gradient
}
System.out.println("];");
return xk.asArray();
}
private static Vector grad(Function<Vector, Double> f, Vector x, double fx) {
final int n = x.getLength();
Vector gradF = new Vector(n); //compute the gradient
for (int i = 0; i < n; i ++) {
Vector xph = x.plus(Vector.unit(i,n).times(DEL_X));
double fxph = f.apply(xph);
gradF.setElement(i, (fxph-fx)/DEL_X);
}
for (double d: x.asArray())
System.out.print(d+", ");
// System.out.println(gradF.getElement(0)+", "+gradF.getElement(1)+";");
System.out.println(fx+";");
return gradF;
}
private static Matrix hessian(Function<Vector, Double> f, Vector x, double fx) {
final int n = x.getLength();
double[] values = new double[(int)Math.pow(3, n-1)*2+1]; //points in array placed with ternary coordinates
values[0] = fx;
for (int i = 0; i < n; i ++) { //for each primary dimension
for (int j = i; j < n; j ++) { //for each secondary dimension (skip a few to prevent redundant calculations)
int k = (int)Math.pow(3, i) + (int)Math.pow(3, j); //calculate the ternary index
Vector dx = Vector.unit(i, n).plus(Vector.unit(j, n)).times(DEL_X); //go a bit in both directions
values[k] = f.apply(x.plus(dx)); //calculate and save
}
int k = (int)Math.pow(3, i); //do the same with just i, no j
Vector dx = Vector.unit(i, n).times(DEL_X);
values[k] = f.apply(x.plus(dx));
}
// System.out.println(Arrays.toString(values));
Matrix h = new Matrix(n, n);
for (int i = 0; i < n; i ++) { //compute the derivatives and fill the matrix
for (int j = i; j < n; j ++) {
int dxi = (int)Math.pow(3, i);
int dxj = (int)Math.pow(3, j);
double dfdx0 = (values[dxi] - values[0])/DEL_X;
double dfdx1 = (values[dxi+dxj] - values[dxj])/DEL_X;
// System.out.println(dfdx0+", "+dfdx1);
double d2fdx2 = (dfdx1 - dfdx0)/DEL_X;
h.setElement(i, j, d2fdx2);
h.setElement(j, i, d2fdx2);
}
}
final Series<Number, Number> output = new Series<Number, Number>();
output.setName(proj.getName());
log.println("We got the best " + proj.getName() + " projections using:");
for (double[] best : currentBest) {
log.print("\t");
for (int i = 0; i < params.length; i++)
log.print("t" + i + "=" + best[3 + i] + "; ");
log.println("\t(" + best[1] + ", " + best[2] + ")");
output.getData().add(new Data<Number, Number>(best[1], best[2]));
}
log.println();
return output;
return h;
}
private static final double weighDistortion(double weight, double... distortions) {
return distortions[0]*weight + distortions[1]*(1-weight);
}
private static Data<Number, Number> plotDistortion(double[][][] pts, Projection proj, double[] params) {
double[] distortion = proj.avgDistortion(pts, proj.getDefaultParameters());
return new Data<Number, Number>(distortion[0], distortion[1]);
}
}

View File

@ -305,6 +305,17 @@ public abstract class Projection {
}
public static double[][][] hemisphere(double dt) { //like globe(), but for the eastern hemisphere. Good for doing projections that are symmetrical in longitude (i.e. pretty much all of them)
List<double[]> points = new ArrayList<double[]>();
for (double phi = -Math.PI/2+dt/2; phi < Math.PI/2; phi += dt) { // make sure phi is never exactly +-tau/4
for (double lam = dt/Math.cos(phi)/2; lam < Math.PI; lam += dt/Math.cos(phi)) {
points.add(new double[] {phi, lam});
}
}
return new double[][][] {points.toArray(new double[0][])};
}
public double[] avgDistortion(double[][][] points, double[] params) {
this.setParameters(params);
return avgDistortion(points);
@ -368,7 +379,7 @@ public abstract class Projection {
final double s1ps2 = Math.hypot((pE[0]-pC[0])+(pN[1]-pC[1]), (pE[1]-pC[1])-(pN[0]-pC[0]));
final double s1ms2 = Math.hypot((pE[0]-pC[0])-(pN[1]-pC[1]), (pE[1]-pC[1])+(pN[0]-pC[0]));
output[1] = Math.abs(Math.log(Math.abs((s1ps2-s1ms2)/(s1ps2+s1ms2)))); //the first output is the shape (angle) distortion
if (Math.abs(output[1]) > 25)
if (output[1] > 25)
output[1] = Double.NaN; //discard outliers
return output;

View File

@ -0,0 +1,199 @@
/**
* MIT License
*
* Copyright (c) 2017 Justin Kunimune
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package utils.linalg;
/**
* A two-dimensional array of numbers to which you can do linear algebra
* FYI these Matrices index from zero because iatshuip MatLab
* This class isn't very efficient, but I'm using it on matrices of max length 3, so
*
* @author jkunimune
*/
public class Matrix {
private final double[][] values;
public Matrix(int n, int m) {
this.values = new double[n][m];
}
public Matrix(double[]... values) {
this.values = values;
}
public static Matrix identity(int n) {
Matrix identity = new Matrix(n, n);
for (int i = 0; i < n; i ++)
identity.setElement(i, i, 1);
return identity;
}
public int getHeight() {
return this.values.length;
}
public int getWidth() {
if (this.getHeight() > 0)
return this.values[0].length;
else
return 0;
}
public double getElement(int i, int j) {
return this.values[i][j];
}
public void setElement(int i, int j, double val) {
this.values[i][j] = val;
}
protected double[][] getArray() {
return values;
}
public Matrix T() {
return this.transpose();
}
public Matrix transpose() {
Matrix thisT = new Matrix(this.getWidth(), this.getHeight());
for (int i = 0; i < this.getWidth(); i ++)
for (int j = 0; j < this.getHeight(); j ++)
thisT.setElement(i, j, this.getElement(j, i));
return thisT;
}
public Matrix inverse() {
if (this.getWidth() != this.getHeight())
throw new IllegalArgumentException("Only square matrices, have inverses. "+this+" is not square.");
double det = this.determinant();
if (det == 0)
System.err.println(this+" is singular and thus has an infinite determinant.");
Matrix inv = new Matrix(this.getHeight(), this.getWidth());
for (int i = 0; i < this.getHeight(); i ++) {
for (int j = 0; j < this.getWidth(); j ++) {
inv.setElement(i, j, this.cofactor(j, i)/det);
}
}
return inv;
}
public double determinant() {
if (this.getWidth() != this.getHeight())
throw new IllegalArgumentException("Only square matrices have determinants. "+this+" is not square.");
if (this.getHeight() == 0)
return 1;
if (this.getHeight() == 1)
return this.getElement(0, 0);
if (this.getHeight() == 2)
return this.getElement(0, 0) * this.getElement(1, 1) -
this.getElement(0, 1) * this.getElement(1, 0);
double det = 0;
for (int j = 0; j < this.getWidth(); j ++) {
if (j%2 == 0)
det += this.getElement(0, j)*this.submatrix(0, j).determinant();
else
det -= this.getElement(0, j)*this.submatrix(0, j).determinant();
}
return det;
}
public Matrix submatrix(double I, double J) {
Matrix sub = new Matrix(this.getHeight()-1, this.getWidth()-1);
for (int i = 0; i < this.getHeight()-1; i ++)
for (int j = 0; j < this.getWidth()-1; j ++)
sub.setElement(i, j, this.getElement(i<I ? i : i+1, j<J ? j : j+1));
return sub;
}
public double cofactor(double i, double j) {
if ((i + j)%2 == 0)
return this.submatrix(i, j).determinant();
else
return -this.submatrix(i, j).determinant();
}
public Matrix plus(Matrix that) {
if (this.getHeight() != that.getHeight() || this.getWidth() != that.getWidth())
throw new IllegalArgumentException("Matrix dimensions must match. Cannot multiply\n"+this+" by\n"+that);
Matrix sum = new Matrix(this.getHeight(), this.getWidth());
for (int i = 0; i < this.getHeight(); i ++)
for (int j = 0; j < this.getWidth(); j ++)
sum.setElement(i, j, this.getElement(i, j) + that.getElement(i, j));
return sum;
}
public Matrix minus(Matrix that) {
return this.plus(that.times(-1));
}
public Matrix times(double c) {
Matrix out = new Matrix(this.getHeight(), this.getWidth());
for (int i = 0; i < this.getHeight(); i ++)
for (int j = 0; j < this.getWidth(); j ++)
out.setElement(i, j, this.getElement(i, j)*c);
return out;
}
public Matrix times(Matrix that) {
if (this.getWidth() != that.getHeight())
throw new IllegalArgumentException("Matrix dimensions must match. Cannot multiply\n"+this+" by\n"+that);
Matrix product = new Matrix(this.getHeight(), that.getWidth());
for (int i = 0; i < this.getHeight(); i ++)
for (int j = 0; j < that.getWidth(); j ++)
for (int k = 0; k < this.getWidth(); k ++)
product.setElement(i, j, product.getElement(i, j)
+ this.getElement(i, k) * that.getElement(k, j));
return product;
}
public String toString() {
String str = "[ ";
for (int i = 0; i < this.getHeight(); i ++) {
for (int j = 0; j < this.getWidth(); j ++)
str += this.getElement(i, j) + ", ";
str = str.substring(0, str.length()-2) + ";\n ";
}
return str.substring(0, str.length()-4) + " ]";
}
public static final void main(String[] args) {
Matrix a = new Matrix(new double[][] {{1,4,7},{3,0,5},{-1,9,11}});
System.out.println(a); //TODO: delete this later
System.out.println(a.determinant());
System.out.println(a.submatrix(1, 2));
System.out.println(a.cofactor(1, 2));
System.out.println(a.inverse());
System.out.println(a.inverse().transpose());
System.out.println(a.transpose().times(a.inverse().transpose()));
}
}

View File

@ -0,0 +1,108 @@
/**
* MIT License
*
* Copyright (c) 2017 Justin Kunimune
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package utils.linalg;
/**
* A column vector to which you can do linear algebra.
* This class isn't very efficient, but I'm using it on vectors of max length 3, so
*
* @author jkunimune
*/
public class Vector extends Matrix {
public Vector(int n) {
super(n, 1);
}
public Vector(double... values) {
super(values.length, 1);
for (int i = 0; i < values.length; i ++)
this.setElement(i, 0, values[i]);
}
private Vector(double[][] values) {
super(values);
if (values[0].length != 1)
throw new IllegalArgumentException("Matrix has width "+values[0].length+" and can therefore not be converted to a column vector.");
}
public static Vector fromMatrix(Matrix mat) {
return new Vector(mat.getArray());
}
public static Vector unit(int i, int n) {
Vector iHat = new Vector(n);
iHat.setElement(i, 1);
return iHat;
}
public int getLength() {
return this.getHeight();
}
public double getElement(int i) {
return this.getElement(i, 0);
}
public void setElement(int i, double val) {
this.setElement(i, 0, val);
}
public double[] asArray() {
double[] arr = new double[this.getLength()];
for (int i = 0; i < arr.length; i ++)
arr[i] = this.getElement(i, 0);
return arr;
}
public Vector plus(Vector that) {
return new Vector(super.plus(that).getArray());
}
public Vector minus(Vector that) {
return new Vector(super.minus(that).getArray());
}
public Vector times(double c) {
return new Vector(super.times(c).getArray());
}
public double dot(Vector that) {
return this.transpose().times(that).getElement(0, 0);
}
public double norm() {
double s = 0;
for (int i = 0; i < this.getLength(); i ++)
s += Math.pow(this.getElement(i), 2);
return Math.sqrt(s);
}
public Vector hat() {
return this.times(1/this.norm());
}
}