#!/usr/bin/perl

package Matrix;



#---------------------------------------------------------------------
# method: show
#
# print the matrix as text to the currently selected output handle:
# first a rule of dashes, then one line per row in which every entry is
# printed as a right-aligned integer in a field $width characters wide,
# followed by a single space
#
# $aref  - matrix (array of row array refs); the number of columns is
#          taken from the first row
# $width - field width used for each entry

sub show
{
   my ($aref, $width) = @_;

   my $ncols   = scalar @{ $aref->[0] };
   my $row_fmt = "%${width}d " x $ncols;
   my $rule    = '-' x ( ($width + 1) * $ncols );

   print "$rule\n";
   for my $r (@$aref)
   {
      printf $row_fmt, @$r;
      print "\n";
   }
}

#---------------------------------------------------------------------
# method: show_mathematica
#
# print the matrix as a Mathematica InputForm assignment, e.g.
#
#    m = {
#       { 1, 2 },
#       { 3, 4 }
#    };
#
# followed by a blank line
#
# $aref  - matrix (array of row array refs)
# $label - name placed on the left hand side of the assignment

sub show_mathematica
{
   my ($aref, $label) = @_;

   print "$label = {\n";

   # one "   { a, b, ... }" line per row
   my @lines = map { sprintf '   { %s }', join ', ', @$_ } @$aref;

   print join (",\n", @lines), "\n}";
   print ";\n\n";
}


#---------------------------------------------------------------------
# method: random (class method)
#
# build an n x n matrix whose entries are int(rand($p)), i.e. integers
# drawn from 0 .. p-1 (elements of the field Zp), blessed into $class
#
# $n - dimension of the n x n matrix
# $p - prime modulus

sub random
{
   my ($class, $n, $p) = @_;

   my @rows;
   my $i = 0;
   while ($i < $n)
   {
      push @rows, [ map { int rand $p } 1 .. $n ];
      $i++;
   }

   return bless \@rows, $class;
}


#---------------------------------------------------------------------
# method: copy
#
# returns a copy of the matrix, blessed into the same class as $in.
# Every row is copied into a new array, so changing entries of the copy
# does not affect the original.  Entries are copied as plain scalars
# (assumes they are numbers, not references).

sub copy
{
   my ($in) = @_;

   # Build the outer array ref up front so that an empty matrix still
   # yields a blessable reference (bless on undef dies).
   my $out = [ map { [ @$_ ] } @$in ];

   return bless $out, ref $in;
}


#---------------------------------------------------------------------
# inverse_mod_p
#
# returns the multiplicative inverse of the first argument in the
# field Zp, p prime
#
# arg 1 - value to invert (reduced mod p first, so any integer is accepted)
# arg 2 - modulus of Zp (p must be prime)
#
# By Fermat's little theorem (Euler's theorem with phi(p) = p-1):
#     a ** (p-1) = 1 (mod p)   for a not divisible by p
# so  a ** (p-2) x a = 1 (mod p), i.e. a ** (p-2) is the inverse of a.
#
# a ** (p-2) is computed by square-and-multiply: walk the bits of the
# exponent from least significant upward, squaring the base at every
# step and multiplying it into the result whenever the current bit is 1.
#
# If the value is a multiple of p it has no inverse; for p > 2 the
# returned value is then 0.

sub inverse_mod_p 
{
   my ($base, $p) = @_;

   my $exponent = $p - 2;
   my $acc      = 1;

   $base %= $p;

   for ( ; $exponent; $exponent >>= 1)
   {
      printf "\$power=%032b, \$a=%-10d  \$results=%-10d\n", $exponent, $base, $acc if $DEBUG;
      $acc  = $acc * $base % $p if $exponent & 1;
      $base = $base * $base % $p;
   }

   return $acc;
}


#---------------------------------------------------------------------
# _rowswap
#
# exchange rows $r and $c of matrix $m in place.  Only the row
# references are swapped; the rows' contents are not copied.
#
# $m      - matrix (array of row array refs), modified in place
# $r, $c  - indices of the two rows to exchange

sub _rowswap
{
   my ($m, $r, $c) = @_;

   # list assignment evaluates the right hand slice before assigning
   @$m[$r, $c] = @$m[$c, $r];

   if ($DEBUG)
   {
      print "swapping rows $r and $c\n";
      # fixed cell width: the modulus is not passed to this helper
      show($m, 5);
   }
}


#---------------------------------------------------------------------
# _rowdiv
#
# multiply every entry of a row by $inverse, reducing mod $p, in place.
# invert() passes the inverse of the pivot here so that the pivot
# entry becomes 1.
#
# $row     - array ref of row entries, modified in place
# $inverse - multiplier (invert() passes inverse_mod_p(pivot, $p))
# $p       - prime modulus

sub _rowdiv
{
   my ($row, $inverse, $p) = @_;
   print "\n_rowdiv @$row * $inverse\n\n" if $DEBUG;

   # foreach aliases each element, so the assignment updates the row itself
   foreach my $entry (@$row)
   {
      $entry = $entry * $inverse % $p;
   }

   # show the row that was just scaled
   show([$row], length $p) if $DEBUG;
}


#---------------------------------------------------------------------
# _rowsub
#
# row1 := row1 - mul * row2 (mod p), in place
#
# $row1 - array ref of the row being reduced, modified in place
# $row2 - array ref of the row being subtracted (not modified)
# $mul  - multiplier applied to $row2
# $p    - prime modulus

sub _rowsub
{
   my ($row1, $row2, $mul, $p) = @_;
   print "\n_rowsub @$row1, @$row2, multiplier $mul\n\n" if $DEBUG;

   for my $c (0 .. $#$row1)
   {
      $row1->[$c] -= $row2->[$c] * $mul % $p;
      # with a positive right operand Perl's % never yields a negative value
      $row1->[$c] %= $p;
   }

   # show the row that was just reduced
   show([$row1], length $p) if $DEBUG;
}


#---------------------------------------------------------------------
# method: invert
#
# computes the inverse of a square matrix over Zp
#
# $self  matrix to invert (array of row array refs, row major order)
# $p     prime order of field Zp that underlies the matrix
#
# returns a new matrix, blessed into the same class as $self, holding the
# inverse, or 0 if the matrix is singular mod p.  $self is not modified.
#
# Uses Gauss-Jordan elimination with row swaps.  Since the arithmetic is
# exact in Zp, any nonzero value is an equally good pivot.  If the
# diagonal entry of the current column is zero, the first row below it
# with a nonzero entry in that column is swapped up.  If there is no such
# row, the matrix is singular and cannot be inverted.

sub invert
{
   my ($self, $p) = @_;

   my $size  = @$self;
   my $width = length $p;    # debug cell width, enough for any residue < p

   # Work on a copy whose entries are reduced into 0 .. p-1.  Otherwise an
   # entry such as p (nonzero to Perl, zero in Zp) could be chosen as a
   # pivot even though it has no inverse.
   my $in = $self->copy();
   foreach my $row (@$in)
   {
      $_ %= $p for @$row;
   }

   # $id starts as the identity matrix.  Applying the same row operations
   # to it as to $in turns it into the inverse.
   my $id = bless [], ref $in;
   for my $i (0 .. $size - 1)
   {
      my @row = (0) x $size;
      $row[$i] = 1;
      push @$id, \@row;
   }

   show($id, $width) if $DEBUG;

   # row reduce $in (the input matrix) | $id (the identity matrix)

   for my $c (0 .. $size - 1)
   {
      # find the first row at or below the diagonal with a nonzero entry
      my $r = $c;
      $r++ while $r < $size && $in->[$r][$c] == 0;
      my $pivot = $r < $size ? $in->[$r][$c] : 0;
      print "pivot is $pivot\n\n" if $DEBUG;

      return 0 unless $pivot;    # no pivot available, matrix is singular

      if ($r != $c)
      {
         _rowswap($in, $r, $c);
         _rowswap($id, $r, $c);
      }

      # scale the pivot row so that the pivot becomes 1
      my $inverse = inverse_mod_p($pivot, $p);
      _rowdiv($in->[$c], $inverse, $p);
      _rowdiv($id->[$c], $inverse, $p);

      print "===============================================\n\n" if $DEBUG;

      # clear column $c in every other row
      for my $k (0 .. $size - 1)
      {
         next if $k == $c;
         my $mult = $in->[$k][$c];
         _rowsub($in->[$k], $in->[$c], $mult, $p);
         _rowsub($id->[$k], $id->[$c], $mult, $p);
         if ($DEBUG)
         {
            print "working on column $c, row $k\n";
            print "input\n";   show($in, $width);
            print "inverse\n"; show($id, $width);
         }
      }
      print "===============================================\n\n" if $DEBUG;
   }

   return $id;
}



#---------------------------------------------------------------------
# method: multiply
#
# returns the matrix product $self x $other over Zp as a new matrix,
# blessed into the same class as $self
#
# $self   left hand matrix over the field Zp, p prime  (r x n)
# $other  right hand matrix over the field Zp          (n x c)
# $p      prime order of the field Zp
#
# The result is r x c (square matrices are the common case), and every
# entry is reduced into 0 .. p-1.  The inner dimension n is taken from
# the first row of $self; empty operands give an empty result.

sub multiply
{
   my ($self, $other, $p) = @_;

   my $rows  = @$self;
   my $inner = $rows   ? scalar @{ $self->[0] }  : 0;
   my $cols  = @$other ? scalar @{ $other->[0] } : 0;

   my @out;
   for my $r (0 .. $rows - 1)
   {
      my @row;
      for my $c (0 .. $cols - 1)
      {
         my $dot = 0;
         # reduce each product so the running sum stays small
         $dot += $self->[$r][$_] * $other->[$_][$c] % $p for 0 .. $inner - 1;
         push @row, $dot % $p;
      }
      push @out, \@row;
   }

   return bless \@out, ref $self;
}

#---------------------------------------------------------------------
# method: conjugate
#
# returns the conjugate of the matrix by $x over Zp,
#     x^-1 . a . x
# as a new matrix.  Neither input is modified.
#
# $self  the matrix a being conjugated
# $x     conjugating matrix (must be invertible mod p)
# $p     prime order of the field Zp
#
# returns 0 if $x is singular mod p, following invert's convention

sub conjugate
{
   my ($self, $x, $p) = @_;

   # invert returns a new matrix (or 0) rather than changing $x, so keep
   # its result
   my $x_inv = $x->invert($p);
   return 0 unless $x_inv;

   return $x_inv->multiply($self, $p)->multiply($x, $p);
}


1;