Matrix.java:


/*
 * Matrix.java
 * Created on 2009-04-02
 *
 * Revision History:
 *  2009-04-02
 *    - original
 */


import java.util.ArrayList;
import java.util.Collections;

/**
 * The Matrix class represents a 2D matrix.
 *
 * @author Mark S. Hancock [MSH]
 * @see <a href="http://mathworld.wolfram.com/Matrix.html">Mathworld</a>
 *
 */

public class Matrix
{
    /**
     * The instance variable that stores the matrix data
     */

    private ArrayList<ArrayList<Double>> matrix;

    /**
     * Creates a new matrix with the specified number of rows and columns and
     * zeroes in every cell.
     *
     * @param rows
     *            The number of rows in this matrix
     * @param columns
     *            The number of columns in this matrix
     * @exception IllegalArgumentException
     *                if the number of rows or columns is not positive
     */

    public Matrix(int rows, int columns)
    {
        if (rows <= 0 || columns <= 0)
        {
            throw new IllegalArgumentException(
                    "The number of rows and columns most both be positive.");
        }

        matrix = new ArrayList<ArrayList<Double>>(rows);
        for (int i = 0; i < rows; i++)
        {
            ArrayList<Double> row = new ArrayList<Double>(columns);
            Collections.fill(row, 0.0);
            matrix.add(row);
        }
    }

    /**
     * Returns the number of rows in this matrix.
     *
     * @return the number of rows in this matrix
     */

    public int getRows()
    {
        return matrix.size();
    }

    /**
     * Returns the number of columns in this matrix.
     *
     * @return the number of columns in this matrix
     */

    public int getColumns()
    {
        return matrix.get(0).size();
    }

    /**
     * Gets the value of the element at row i and column j.
     *
     * @param i
     *            The row number
     * @param j
     *            The column number
     * @return The value at row i and column j
     */

    public double get(int i, int j)
    {
        return matrix.get(i).get(j);
    }

    /**
     * Sets the value of the element at row i and column j.
     *
     * @param i
     *            The row number
     * @param j
     *            The column number
     * @param value
     *            The value to set the element at row i and column j to
     */

    public void set(int i, int j, double value)
    {
        matrix.get(i).set(j, value);
    }

    /**
     * Multiplies this matrix by the matrix stored in 'right'. The number of
     * columns of this matrix must equal the number of rows in 'right'. The
     * resulting matrix will have the number of rows that this matrix has and
     * the number of columns in 'right'. The ith row and jth column of the
     * resulting matrix is calculated as the dot product of the ith row of this
     * matrix with the jth column of 'right'. For example:
     *
     * <pre>
     *      [ r00 r01 ]   [ a00 a01 a02 ]   [ b00 b01 ]
     *      [ r10 r11 ] = [ a10 a11 a12 ] x [ b10 b11 ]
     *                                      [ b20 b21 ]
     *                                      
     *                  = [ (a00xb00 + a01xb10 + a02xb20) (a00xb01 + a01xb11 + a02xb21) ]
     *                    [ (a10xb00 + a11xb10 + a12xb20) (a10xb01 + a11xb11 + a12xb21) ]
     * </pre>
     *
     * @param right
     *            the matrix to multiply this matrix by
     * @return the result of the matrix multiplication
     */

    public Matrix multiply(Matrix right)
    {
        if (this.getColumns() != right.getRows())
        {
            throw new IllegalArgumentException("Cannot multiply a "
                    + this.getRows() + " by " + this.getColumns()
                    + " matrix with a " + right.getRows() + " by "
                    + right.getColumns() + " matrix");
        }

        Matrix result = new Matrix(this.getRows(), right.getColumns());
        for (int i = 0; i < result.getRows(); i++)
        {
            for (int j = 0; j < result.getColumns(); j++)
            {
                double dotProduct = 0.0;
                for (int k = 0; k < this.getColumns(); k++)
                {
                    dotProduct += this.get(i, k) * right.get(k, j);
                }

                result.set(i, j, dotProduct);
            }
        }

        return result;
    }
}