(** * Library CFEM.floatlib *)

Require Import VST.floyd.proofauto.
Require Import vcfloat.VCFloat.
Require Import Coq.Relations.Relations Coq.Classes.Morphisms Coq.Classes.RelationPairs Coq.Classes.RelationClasses.
Require Export vcfloat.FPStdLib.

(** Dense containers over a VCFloat format [t].
    A [vector] is a plain list of floats.
    A [matrix] is a list of such lists, whose outer list is indexed first. *)
Definition matrix (t: type) := list (list (ftype t)).
Definition vector (t: type) := list (ftype t).

(** [dotprod v1 v2] is the floating-point dot product of [v1] and [v2].
    - The element pairs come from [List.combine], so the longer list is
      truncated to the length of the shorter one.
    - The pairs are folded left to right, starting from [Zconst t 0].
    - Each step computes [BFMA (fst p) (snd p) acc], VCFloat's fused
      multiply-add of the pair into the running accumulator. *)
Definition dotprod {NAN: Nans} {t: type} (v1 v2: list (ftype t)) : ftype t :=
  fold_left (fun s x12 => BFMA (fst x12) (snd x12) s)
                (List.combine v1 v2) (Zconst t 0).

(** [norm2 v] is the squared Euclidean norm of [v] (no square root is
    taken). It is computed as the floating-point dot product of [v] with
    itself. *)
Definition norm2 {NAN: Nans} {t: type} (v: vector t) : ftype t :=
  dotprod v v.

(** [matrix_vector_mult m v] multiplies the matrix [m] by the vector [v].
    Entry [i] of the result is [dotprod] of the [i]-th inner list (row) of
    [m] with [v], so the result has exactly one entry per row of [m]. *)
Definition matrix_vector_mult {NAN: Nans}{t: type} (m: matrix t) (v: vector t) : vector t :=
      map (fun row => dotprod row v) m.

(** [matrix_matrix_mult m1 m2] applies [matrix_vector_mult m1] to every
    inner list of [m2]. The result therefore has one inner list per inner
    list of [m2].
    NOTE(review): this equals the product [m1 * m2] only if [m2] and the
    result are read as lists of columns, while [m1] is read as a list of
    rows. Confirm the intended layout with callers. *)
Definition matrix_matrix_mult {NAN: Nans} {t: type}
    (m1: matrix t) (m2: matrix t) : matrix t :=
  map (matrix_vector_mult m1) m2.

(** [matrix_cols m cols] holds when every row of [m] has [Zlength] equal to
    [cols], which is a [Z]. It constrains only the rows' lengths, not the
    number of rows. *)
Definition matrix_cols {t} (m: matrix t) cols :=
    Forall (fun r => Zlength r = cols) m.

(** Number of rows (outer-list entries) of [m], measured as a [Z] via
    [Zlength]. *)
Definition matrix_rows {t: type} (m: matrix t) : Z :=
  Zlength m.

(** [map2 f al bl] applies the binary function [f] pointwise to [al] and
    [bl]. Because it goes through [List.combine], any extra elements of the
    longer list are silently dropped. *)
Definition map2 {A B C: Type} (f: A -> B -> C) al bl :=
  map (uncurry f) (List.combine al bl).

(** Entrywise negation of a matrix: [BOPP] is applied to every element of
    every row. *)
Definition opp_matrix {NAN: Nans} {t: type} (m: matrix t) : matrix t :=
  map (map (@BOPP NAN t))
      m.

(** Entrywise floating-point sum of two matrices, applying [BPLUS] to
    matching entries. Rows and entries beyond the shorter operand are
    dropped, because [map2] truncates. *)
Definition matrix_add {NAN: Nans}{t} : matrix t -> matrix t -> matrix t :=
  map2 (map2 (@BPLUS _ t)).

(** Pointwise floating-point sum ([BPLUS]) of two vectors. The result is
    truncated to the shorter argument, as [map2] does. *)
Definition vector_add {NAN: Nans} {t: type}
    (v1: vector t) (v2: vector t) :=
  map2 (@BPLUS _ t)
       v1 v2.

(** Pointwise floating-point difference ([BMINUS]): each element of [v2] is
    subtracted from the matching element of [v1]. The result is truncated
    to the shorter argument, as [map2] does. *)
Definition vector_sub {NAN: Nans} {t: type}
    (v1: vector t) (v2: vector t) :=
  map2 (@BMINUS _ t)
       v1 v2.

(** [matrix_index m i j] is entry [j] of row [i] of [m].
    If either index is out of range, the result falls back to
    [Zconst t 0]; a missing row behaves like the empty list. *)
Definition matrix_index {t: type} (m: matrix t) (i j: nat) : ftype t :=
  nth j (nth i m nil)
      (Zconst t 0).

(** [matrix_by_index rows cols f] builds a matrix with [rows] rows.
    For [i] from 0 to [rows-1], row [i] is the list of [f i j] for [j] from
    0 to [cols-1]. *)
Definition matrix_by_index {t} (rows cols: nat)
          (f: nat -> nat -> ftype t) : matrix t :=
     map (fun i => map (f i) (seq 0 cols)) (seq 0 rows).

(** Number of rows (outer-list entries) of [m], measured as a [nat]. *)
Definition matrix_rows_nat {t: type} (m: matrix t) : nat :=
  length m.

(** [matrix_cols_nat m cols] holds when every row of [m] has [length] equal
    to [cols], a [nat]. It is the [nat] counterpart of [matrix_cols]. *)
Definition matrix_cols_nat {t} (m: matrix t) cols :=
    Forall (fun r => length r = cols) m.

(** [matrix_by_index] produces exactly [rows] rows. *)
Lemma matrix_by_index_rows:
  forall {t} rows cols (f: nat -> nat -> ftype t),
  matrix_rows_nat (matrix_by_index rows cols f) = rows.

(** [faster_transpose' row A] produces one output list per element of [row].
    For the [k]-th element [x] of [row], the output list holds the heads of
    the rows of [A] after [k] applications of [tl]. Any row that has run out
    contributes [x] as a default. Apart from serving as these defaults, the
    elements of [row] only determine how many output lists are produced. *)
Fixpoint faster_transpose' {t} (row: list t) (A: list (list t)) : list (list t) :=
 match row with
 | (x::row') => map (hd x) A :: faster_transpose' row' (map (@tl t) A)
 | _ => nil
 end.

(** Transposes [A] by calling [faster_transpose'] with the first row of [A].
    The output has one row per entry of that first row, and [nil]
    transposes to [nil]. *)
Definition faster_matrix_transpose {t} (A: list (list t)) : list (list t) :=
 match A with row::_ => faster_transpose' row A | _ => nil end.

(** Reference specification of transposition.
    If [A] is non-empty and its first row has length [n], the result has
    [n] rows of length [matrix_rows_nat A], and entry [(i, j)] is
    [matrix_index A j i]. The empty matrix transposes to [nil]. *)
Definition matrix_transpose {t} (A: matrix t) : matrix t :=
 match A with nil => nil
 | row::_ => matrix_by_index (length row) (matrix_rows_nat A)
               (fun i j => matrix_index A j i)
 end.

(** On a matrix whose rows all have the same length, the list-based
    transpose agrees with the index-based specification [matrix_transpose]. *)
Lemma faster_matrix_transpose_correct:
  forall {t} cols (A: matrix t),
  matrix_cols A cols ->
  faster_matrix_transpose A = matrix_transpose A.

(* The following lemmas state their index bounds in nat (e.g. [i < rows]).
   From this point on, numerals and comparison operators are therefore read
   in nat_scope, until the Z scope is reopened further below. *)
Local Open Scope nat.

(** Every row produced by [matrix_by_index] has length [cols]. *)
Lemma matrix_by_index_cols:
  forall {t} rows cols (f: nat -> nat -> ftype t),
  matrix_cols_nat (matrix_by_index rows cols f) cols.

(** Indexing [map f (seq 0 n)] at any in-range position [i] yields [f i],
    whatever the default [d]. *)
Lemma nth_map_seq:
  forall {A} (f: nat -> A) d (i n: nat), i < n -> nth i (map f (seq 0 n)) d = f i.

(** Within bounds, entry [(i, j)] of [matrix_by_index rows cols f] is
    [f i j]. *)
Lemma matrix_by_index_index:
  forall {t} rows cols (f: nat -> nat -> ftype t) i j,
   i < rows -> j < cols ->
   matrix_index (matrix_by_index rows cols f) i j = f i j.

(** Leibniz extensionality for matrices. Two matrices are equal when they
    have the same number of rows, all of their rows have length [cols], and
    all of their in-range entries are equal. *)
Lemma matrix_extensionality_strong:
  forall {t} (m1 m2: matrix t) cols,
  matrix_rows_nat m1 = matrix_rows_nat m2 ->
  matrix_cols_nat m1 cols -> matrix_cols_nat m2 cols ->
  (forall i j, i < matrix_rows_nat m1 -> j < cols ->
        matrix_index m1 i j = matrix_index m2 i j) ->
    m1 = m2.

(** Extensionality up to [feq]. Matrices with the same shape whose in-range
    entries are pairwise [feq] are related by [Forall2 (Forall2 feq)]. *)
Lemma matrix_extensionality:
  forall {t} (m1 m2: matrix t) cols,
  matrix_rows_nat m1 = matrix_rows_nat m2 ->
  matrix_cols_nat m1 cols -> matrix_cols_nat m2 cols ->
  (forall i j, i < matrix_rows_nat m1 -> j < cols ->
        feq (matrix_index m1 i j) (matrix_index m2 i j)) ->
  Forall2 (Forall2 feq) m1 m2.

(** If every entry of a matrix whose rows all have length [cols] satisfies
    [P], then every in-range [matrix_index] satisfies [P]. *)
Lemma matrix_index_prop:
 forall {t} (P: ftype t -> Prop) (m: matrix t) (cols i j : nat),
    matrix_cols_nat m cols ->
    Forall (Forall P) m ->
    i < matrix_rows_nat m -> j < cols ->
    P (matrix_index m i j).

(** Two lists of equal length that agree at every in-range position (under
    the same default [d]) are equal. *)
Lemma all_nth_eq:
 forall {A} d (al bl: list A),
  length al = length bl ->
  (forall i, i < length al -> nth i al d = nth i bl d) ->
  al=bl.

(** Default inhabitant of [ftype t]: the constant 0 of format [t].
    NOTE(review): presumably this is the out-of-range default picked up by
    VST's [Znth] on vectors; confirm. *)
#[export] Instance zerof {t: type} : Inhabitant (ftype t) :=
  Zconst t 0.

(** Appending [x] to a vector updates its [norm2] by one fused
    multiply-add, [BFMA x x], of the previous [norm2]. *)
Lemma norm2_snoc:
  forall {NAN: Nans}{t} (al: vector t) (x: ftype t),
   norm2 (al ++ [x]) = BFMA x x (norm2 al).

(** [dotprod] respects [strict_feq] on both arguments, given that the two
    left-hand vectors have equal length. The results are related by
    [feq]. *)
Lemma dotprod_congr {NAN: Nans}{t} (x x' y y' : vector t):
 Forall2 strict_feq x x' ->
 Forall2 strict_feq y y' ->
 length x = length y ->
 feq (dotprod x y) (dotprod x' y').

(** [norm2] maps pointwise-[feq] vectors to [feq] results. *)
Lemma norm2_congr:
  forall {NAN: Nans} {t} (x x': vector t),
           Forall2 feq x x' ->
           feq (norm2 x) (norm2 x').

(* From here on, numerals and comparison operators are read in Z_scope, to
   match VST's Zlength/Znth indexing. Statements about nat below carry an
   explicit %nat annotation. *)
Local Open Scope Z.

(** For vectors of equal [Zlength], the in-range element [i] of
    [vector_sub x y] is [BMINUS] of the corresponding elements of [x] and
    [y]. *)
Lemma Znth_vector_sub:
 forall {NAN: Nans}{t} i (x y: vector t) , Zlength x = Zlength y ->
   0 <= i < Zlength x ->
   Znth i (vector_sub x y) = BMINUS (Znth i x) (Znth i y).

Lemma vector_sub_congr: {NAN: Nans} {t} (x x' y y': vector t),
  Forall2 feq x x' Forall2 feq y y'
  Forall2 feq (vector_sub x y) (vector_sub x' y').

(** [norm2] maps pointwise-[feq] vectors to [feq] results. This is the same
    statement as [norm2_congr]; it is kept under this name for existing
    callers. *)
Lemma norm2_loose_congr:
 forall {NAN: Nans}{t} (x x': vector t), Forall2 feq x x' -> feq (norm2 x) (norm2 x').

(** For an in-range index, [nth] commutes with [map]. The defaults [d] and
    [d'] are irrelevant because they are never used. *)
Lemma nth_map_inrange {A} (d': A) {B: Type}:
  forall (f: A -> B) i al d,
   (i < length al)%nat ->
   nth i (map f al) d = f (nth i al d').

Lemma finite_dotprod_e: {NAN: Nans}{t} (x y: vector t),
  Zlength x = Zlength y
  finite (dotprod x y) Forall finite x Forall finite y.

Lemma finite_norm2_e: {NAN: Nans}{t} (x: vector t),
  finite (norm2 x) Forall finite x.

(** If [P] holds of [Zconst t 0] and of every in-range [f i j], then [P]
    holds of every entry of [matrix_by_index rows cols f]. *)
Lemma matrix_by_index_prop:
 forall {t} (f: nat -> nat -> ftype t) (P: ftype t -> Prop) rows cols,
  P (Zconst t 0) ->
  (forall i j, (i < rows)%nat -> (j < cols)%nat -> P (f i j)) ->
  Forall (Forall P) (matrix_by_index rows cols f).

(** Converts the [nat] column-count invariant into the [Z] one. *)
Lemma Zmatrix_cols_nat:
 forall {t} (m: matrix t) cols,
  matrix_cols_nat m cols -> matrix_cols m (Z.of_nat cols).

(** [seq lo n] has [Zlength] equal to [n], converted to [Z]. *)
Lemma Zlength_seq: forall lo n, Zlength (seq lo n) = Z.of_nat n.
(* Register Zlength_seq as a left-to-right rewrite rule in the sublist and
   rep_lia autorewrite databases. *)
#[export] Hint Rewrite Zlength_seq : sublist rep_lia.

Lemma Zmatrix_rows_nat: {t} (m: matrix t), Z.of_nat (matrix_rows_nat m) = matrix_rows m.

(* Declare norm2 a setoid morphism from Forall2 feq to feq, so setoid
   rewriting can rewrite feq-related vectors underneath norm2. The proof
   obligation is discharged right after this command. *)
Add Parametric Morphism {NAN: Nans}{t: type}: (@norm2 _ t)
  with signature Forall2 feq ==> feq
 as norm2_mor.

(* Declare vector_sub a setoid morphism: pointwise-feq arguments give
   pointwise-feq results. This allows setoid rewriting inside either
   argument of vector_sub. *)
Add Parametric Morphism {NAN: Nans}{t: type}: (@vector_sub _ t)
  with signature Forall2 feq ==> Forall2 feq ==> Forall2 feq
  as vector_sub_mor.

(* Lists related pointwise by any relation rel have equal Zlength. This
   lets setoid rewriting with Forall2 rel pass through Zlength. *)
Add Parametric Morphism {T: Type} (rel: relation T): (@Zlength T)
  with signature Forall2 rel ==> eq
  as Zlength_mor.

(* Declare dotprod a setoid morphism from pointwise feq on both vectors to
   feq on the result. *)
Add Parametric Morphism {NAN: Nans}{t}: (@dotprod _ t)
 with signature Forall2 feq ==> Forall2 feq ==> feq
 as dotprod_mor.

(* Declare matrix_vector_mult a setoid morphism. A matrix related entrywise
   by feq and a vector related pointwise by feq give output vectors that
   are pointwise feq. *)
Add Parametric Morphism {NAN: Nans} {t}: (@matrix_vector_mult _ t)
 with signature Forall2 (Forall2 feq) ==> Forall2 feq ==> Forall2 feq
 as matrix_vector_mult_mor.