(** * QuantumLib.Matrix *)

(** In this file, we define matrices and prove many basic facts from linear algebra. *)

Require Import Psatz.
Require Import String.
Require Import Program.
Require Export Complex.
Require Import List.


(** * Matrix definitions and infrastructure *)



(* Scope for all matrix notations; [%M] selects it explicitly. *)
Declare Scope matrix_scope.
Delimit Scope matrix_scope with M.
Open Scope matrix_scope.

Local Open Scope nat_scope.

(** A matrix is a function from (row, column) indices to complex numbers.
    The dimensions [m] and [n] do not appear in the body of the definition,
    so any two [Matrix] types are convertible; the dimensions only carry
    meaning through [WF_Matrix] and the operations below. *)
Definition Matrix (m n : nat) := nat -> nat -> C.

(** A matrix is well formed when every entry outside its declared
    [m] x [n] bounds is [C0]. *)
Definition WF_Matrix {m n: nat} (A : Matrix m n) : Prop :=
  forall x y, x >= m \/ y >= n -> A x y = C0.

(** Column vectors are [n] x 1 matrices. *)
Notation Vector n := (Matrix n 1).

(** Square [n] x [n] matrices. *)
Notation Square n := (Matrix n n).

(** Equality via functional extensionality: turns a goal [A = B] between
    two-argument functions into the pointwise goal [A x y = B x y] for fresh
    [x] and [y]. *)
Ltac prep_matrix_equality :=
  let x := fresh "x" in
  let y := fresh "y" in
  apply functional_extensionality; intros x;
  apply functional_extensionality; intros y.

(** * Matrix equivalence *)

(** Two matrices are equivalent when they agree on every entry inside the
    declared [m] x [n] bounds; entries outside the bounds are ignored. *)
Definition mat_equiv {m n : nat} (A B : Matrix m n) : Prop :=
  forall i j, i < m -> j < n -> A i j = B i j.

Infix "==" := mat_equiv (at level 70) : matrix_scope.

Lemma mat_equiv_refl : forall m n (A : Matrix m n), mat_equiv A A.

(* For well-formed matrices, equivalence implies equality. *)
Lemma mat_equiv_eq : forall {m n : nat} (A B : Matrix m n),
  WF_Matrix A ->
  WF_Matrix B ->
  A == B ->
  A = B.

(** * Printing *)

(** [print_C] is declared as a parameter; no implementation is given here. *)
Parameter print_C : C -> string.

(* Entries [0 .. j-1] of row [i], left to right, each followed by ", ".
   Built by appending the last entry after the shorter prefix so that the
   output is in increasing column order. *)
Fixpoint print_row_entries {m n} (i j : nat) (A : Matrix m n) : string :=
  match j with
  | 0 => ""
  | S j' => print_row_entries i j' A ++ print_C (A i j') ++ ", "
  end.

(** Renders the first [j] entries of row [i] in increasing column order,
    terminated by "\n". (The previous version emitted the columns in
    reverse order.) *)
Definition print_row {m n} i j (A : Matrix m n) : string :=
  print_row_entries i j A ++ "\n".

(** Renders rows [0 .. i-1], each with [j] columns, in increasing row order.
    (The previous version emitted rows in reverse and ignored [j], always
    using [n].) *)
Fixpoint print_rows {m n} i j (A : Matrix m n) : string :=
  match i with
  | 0 => ""
  | S i' => print_rows i' j A ++ print_row i' j A
  end.

(** Renders all [m] rows and [n] columns of [A]. *)
Definition print_matrix {m n} (A : Matrix m n) : string :=
  print_rows m n A.

(** * 2D list representation *)

(** Builds a matrix from a list of rows. Entry [(x, y)] is the [y]-th
    element of the [x]-th row; missing rows or entries read as [0]. The
    column count in the type is taken from the first row. *)
Definition list2D_to_matrix (l : list (list C)) :
  Matrix (length l) (length (hd [] l)) :=
  (fun x y => nth y (nth x l []) 0%R).

(* If every row has length [n], the result is a well-formed [m] x [n] matrix. *)
Lemma WF_list2D_to_matrix : forall m n li,
    length li = m ->
    (forall li', In li' li -> length li' = n) ->
    @WF_Matrix m n (list2D_to_matrix li).

(* Example: the same 2 x 3 matrix written entrywise ([M23]) and as a
   2D list ([M23']). *)
Definition M23 : Matrix 2 3 :=
  fun x y =>
  match (x, y) with
  | (0, 0) => 1%R
  | (0, 1) => 2%R
  | (0, 2) => 3%R
  | (1, 0) => 4%R
  | (1, 1) => 5%R
  | (1, 2) => 6%R
  | _ => C0
  end.

Definition M23' : Matrix 2 3 :=
  list2D_to_matrix
  ([ [RtoC 1; RtoC 2; RtoC 3];
    [RtoC 4; RtoC 5; RtoC 6] ]).

Lemma M23eq : M23 = M23'.

(** * Operands and operations *)


(** The zero matrix of any dimension. *)
Definition Zero {m n : nat} : Matrix m n := fun x y => 0%R.

(** The [n] x [n] identity: [C1] on the diagonal for indices below [n] and
    [C0] everywhere else, so [I n] is well formed. *)
Definition I (n : nat) : Square n :=
  (fun x y => if (x =? y) && (x <? n) then C1 else C0).

(** An unbounded identity: [C1] on the entire diagonal. *)
Definition I__inf := fun x y => if x =? y then C1 else C0.
Notation "I∞" := I__inf : matrix_scope.

(** Sum of the diagonal entries [0 .. n-1]. *)
Definition trace {n : nat} (A : Square n) :=
  big_sum (fun x => A x x) n.

(** Entrywise multiplication by the scalar [r]. *)
Definition scale {m n : nat} (r : C) (A : Matrix m n) : Matrix m n :=
  fun x y => (r * A x y)%C.

(** Sum of [A x 0 * B x 0] over [x < n]; no complex conjugation is applied. *)
Definition dot {n : nat} (A : Vector n) (B : Vector n) : C :=
  big_sum (fun x => A x 0 * B x 0)%C n.

Definition Mplus {m n : nat} (A B : Matrix m n) : Matrix m n :=
  fun x y => (A x y + B x y)%C.

Definition Mopp {m n : nat} (A : Matrix m n) : Matrix m n :=
  scale (-C1) A.

Definition Mminus {m n : nat} (A B : Matrix m n) : Matrix m n :=
  Mplus A (Mopp B).

(** Matrix product; the inner dimension [n] bounds the summation. *)
Definition Mmult {m n o : nat} (A : Matrix m n) (B : Matrix n o) : Matrix m o :=
  fun x z => big_sum (fun y => A x y * B y z)%C n.

(** Kronecker product: entry [(x, y)] is
    [A (x / o) (y / p) * B (x mod o) (y mod p)]. *)
Definition kron {m n o p : nat} (A : Matrix m n) (B : Matrix o p) :
  Matrix (m*o) (n*p) :=
  fun x y => Cmult (A (x / o) (y / p)) (B (x mod o) (y mod p)).

(** Block-diagonal direct sum. Every index with [x < m] or [y < n] is read
    from [A], so the two off-diagonal blocks are zero only when [A] is well
    formed. *)
Definition direct_sum {m n o p : nat} (A : Matrix m n) (B : Matrix o p) :
  Matrix (m+o) (n+p) :=
  fun x y => if (x <? m) || (y <? n) then A x y else B (x - m) (y - n).

Definition transpose {m n} (A : Matrix m n) : Matrix n m :=
  fun x y => A y x.

(** Conjugate transpose. *)
Definition adjoint {m n} (A : Matrix m n) : Matrix n m :=
  fun x y => (A y x)^*.

(** [u† × v], read at entry (0, 0). *)
Definition inner_product {n} (u v : Vector n) : C :=
  Mmult (adjoint u) (v) 0 0.

(** [u × v†]. *)
Definition outer_product {n} (u v : Vector n) : Square n :=
  Mmult u (adjoint v).

(** Kronecker product of [n] copies of [A]; zero copies give [I 1]. New
    copies are added on the right: [kron_n (S n') A = kron (kron_n n' A) A]. *)
Fixpoint kron_n n {m1 m2} (A : Matrix m1 m2) : Matrix (m1^n) (m2^n) :=
  match n with
  | 0 => I 1
  | S n' => kron (kron_n n' A) A
  end.

(** Kronecker product of a list of matrices, left to right; the empty list
    gives [I 1]. *)
Fixpoint big_kron {m n} (As : list (Matrix m n)) :
  Matrix (m^(length As)) (n^(length As)) :=
  match As with
  | [] => I 1
  | A :: As' => kron A (big_kron As')
  end.

(** Product of [n] copies of the square matrix [A]; zero copies give [I m]. *)
Fixpoint Mmult_n n {m} (A : Square m) : Square m :=
  match n with
  | 0 => I m
  | S n' => Mmult A (Mmult_n n' A)
  end.

(** Direct sum of [n] copies of [A]; zero copies give the 0 x 0 [Zero]. *)
Fixpoint direct_sum_n n {m1 m2} (A : Matrix m1 m2) : Matrix (n*m1) (n*m2) :=
  match n with
  | 0 => @Zero 0 0
  | S n' => direct_sum A (direct_sum_n n' A)
  end.

(** Showing that [Matrix n m] is a vector space over [C]:
    - a monoid under [Mplus] with identity [Zero],
    - a group with inverse [Mopp],
    - a commutative group,
    - a vector space with scalar action [scale].
    Every obligation is discharged pointwise with [lca]. *)

Program Instance M_is_monoid : forall n m, Monoid (Matrix n m) :=
  { Gzero := @Zero n m
  ; Gplus := Mplus
  }.
Solve All Obligations with program_simpl; prep_matrix_equality; lca.

Program Instance M_is_group : forall n m, Group (Matrix n m) :=
  { Gopp := Mopp }.
Solve All Obligations with program_simpl; prep_matrix_equality; lca.

Program Instance M_is_comm_group : forall n m, Comm_Group (Matrix n m).
Solve All Obligations with program_simpl; prep_matrix_equality; lca.

Program Instance M_is_vector_space : forall n m, Vector_Space (Matrix n m) C :=
  { Vscale := scale }.
Solve All Obligations with program_simpl; prep_matrix_equality; lca.

(** Notations *)
Infix "∘" := dot (at level 40, left associativity) : matrix_scope.
Infix ".+" := Mplus (at level 50, left associativity) : matrix_scope.
Infix ".*" := scale (at level 40, left associativity) : matrix_scope.
Infix "×" := Mmult (at level 40, left associativity) : matrix_scope.
Infix "⊗" := kron (at level 40, left associativity) : matrix_scope.
Infix ".⊕" := direct_sum (at level 20) : matrix_scope.
(* A second infix for [mat_equiv], alongside "==". *)
Infix "≡" := mat_equiv (at level 70) : matrix_scope.
Notation "A ⊤" := (transpose A) (at level 0) : matrix_scope.
Notation "A †" := (adjoint A) (at level 0) : matrix_scope.
(* [Σ] is [big_sum] specialised to the [C] monoid. *)
Notation Σ := (@big_sum C C_is_monoid).
Notation "n ⨂ A" := (kron_n n A) (at level 30, no associativity) : matrix_scope.
Notation "⨂ A" := (big_kron A) (at level 60): matrix_scope.
Notation "n ⨉ A" := (Mmult_n n A) (at level 30, no associativity) : matrix_scope.
Notation "⟨ u , v ⟩" := (inner_product u v) (at level 0) : matrix_scope.

(* Definitions unfolded by [autounfold with U_db] (used by [lma]). *)
#[export] Hint Unfold Zero I trace dot Mplus scale Mmult kron mat_equiv transpose
            adjoint : U_db.

(* Destructs a variable [x] that appears in the goal as the scrutinee of a
   [match x with 0 => _ | S _ => _ end]. *)
Ltac destruct_m_1 :=
  match goal with
  | [ |- context[match ?x with
                 | 0 => _
                 | S _ => _
                 end] ] => is_var x; destruct x
  end.
(* Repeats [destruct_m_1], simplifying after each case split. *)
Ltac destruct_m_eq := repeat (destruct_m_1; simpl).

(* Matrix equality by computation: unfold the [U_db] definitions, reduce to
   a pointwise goal, split on index matches, and close each case with [lca]. *)
Ltac lma :=
  autounfold with U_db;
  prep_matrix_equality;
  destruct_m_eq;
  lca.

(* Closes a goal that has a hypothesis of the form [_ < 0]. *)
Ltac solve_end :=
  match goal with
  | H : lt _ O |- _ => apply Nat.nlt_0_r in H; contradict H
  end.

(* Proves a [mat_equiv]-shaped goal cell by cell. It introduces the two
   indices and their bounds, then enumerates every row index and every
   column index allowed by the bounds; impossible cases are closed by
   [solve_end]. *)
Ltac by_cell :=
  intros;
  let i := fresh "i" in
  let j := fresh "j" in
  let Hi := fresh "Hi" in
  let Hj := fresh "Hj" in
  intros i j Hi Hj; try solve_end;
  (* [Nat.succ_lt_mono : n < m <-> S n < S m], used right to left to turn
     [S i < S k] into [i < k]. *)
  repeat (destruct i as [|i]; simpl; [|apply <- Nat.succ_lt_mono in Hi]; try solve_end); clear Hi;
  repeat (destruct j as [|j]; simpl; [|apply <- Nat.succ_lt_mono in Hj]; try solve_end); clear Hj.

(* Matrix equality for well-formed matrices: reduce to [mat_equiv] via
   [mat_equiv_eq], discharge well-formedness with [wf_db], and check each
   in-bounds cell with [lca]. *)
Ltac lma' :=
  apply mat_equiv_eq;
  repeat match goal with
  | [ |- WF_Matrix (?A) ] => auto with wf_db
  | [ |- mat_equiv (?A) (?B) ] => by_cell; try lca
  end.

(* Congruence lemmas: each operation respects equality of its arguments. *)

Lemma kron_simplify : forall (n m o p : nat) (a b : Matrix n m) (c d : Matrix o p),
    a = b -> c = d -> a ⊗ c = b ⊗ d.

(* The power arguments are named [k l] so they do not shadow the
   dimensions [n m]. *)
Lemma n_kron_simplify : forall (n m : nat) (a b : Matrix n m) (k l : nat),
    a = b -> k = l -> k ⨂ a = l ⨂ b.

Lemma Mtranspose_simplify : forall (n m : nat) (a b : Matrix n m),
    a = b -> a⊤ = b⊤.

Lemma Madjoint_simplify : forall (n m : nat) (a b : Matrix n m),
    a = b -> a† = b†.

Lemma Mmult_simplify : forall (n m o : nat) (a b : Matrix n m) (c d : Matrix m o),
    a = b -> c = d -> a × c = b × d.

Lemma Mmult_n_simplify : forall (n : nat) (a b : Square n) (c d : nat),
    a = b -> c = d -> c ⨉ a = d ⨉ b.

(* Both arguments are rewritten: the conclusion uses [d], not [c], on the
   right-hand side. The previous statement ([b ∘ c]) ignored [c = d]. *)
Lemma dot_simplify : forall (n : nat) (a b c d: Vector n),
    a = b -> c = d -> a ∘ c = b ∘ d.

Lemma Mplus_simplify : forall (n m: nat) (a b : Matrix n m) (c d : Matrix n m),
    a = b -> c = d -> a .+ c = b .+ d.

Lemma Mscale_simplify : forall (n m: nat) (a b : Matrix n m) (c d : C),
    a = b -> c = d -> c .* a = d .* b.

(** * Proofs about well-formedness *)


(* Well-formedness is preserved when the dimensions are rewritten by equal ones. *)
Lemma WF_Matrix_dim_change : forall (m n m' n' : nat) (A : Matrix m n),
  m = m' ->
  n = n' ->
  @WF_Matrix m n A ->
  @WF_Matrix m' n' A.

Lemma WF_Zero : forall m n : nat, WF_Matrix (@Zero m n).

Lemma WF_I : forall n : nat, WF_Matrix (I n).

Lemma WF_I1 : WF_Matrix (I 1).

Lemma WF_scale : forall {m n : nat} (r : C) (A : Matrix m n),
  WF_Matrix A -> WF_Matrix (scale r A).

Lemma WF_plus : forall {m n} (A B : Matrix m n),
  WF_Matrix A -> WF_Matrix B -> WF_Matrix (A .+ B).

Lemma WF_mult : forall {m n o : nat} (A : Matrix m n) (B : Matrix n o),
  WF_Matrix A -> WF_Matrix B -> WF_Matrix (A × B).

(* The result dimensions [q r] are given by equations so the lemma applies
   when they appear in an equal but syntactically different form. *)
Lemma WF_kron : forall {m n o p q r : nat} (A : Matrix m n) (B : Matrix o p),
                  q = m * o -> r = n * p ->
                  WF_Matrix A -> WF_Matrix B -> @WF_Matrix q r (A ⊗ B).

Lemma WF_direct_sum : forall {m n o p q r : nat} (A : Matrix m n) (B : Matrix o p),
                  q = m + o -> r = n + p ->
                  WF_Matrix A -> WF_Matrix B -> @WF_Matrix q r (A .⊕ B).

Lemma WF_transpose : forall {m n : nat} (A : Matrix m n),
                     WF_Matrix A -> WF_Matrix A⊤.

Lemma WF_adjoint : forall {m n : nat} (A : Matrix m n),
      WF_Matrix A -> WF_Matrix A†.

Lemma WF_outer_product : forall {n} (u v : Vector n),
    WF_Matrix u ->
    WF_Matrix v ->
    WF_Matrix (outer_product u v).

Lemma WF_kron_n : forall n {m1 m2} (A : Matrix m1 m2),
   WF_Matrix A ->  WF_Matrix (kron_n n A).

Lemma WF_big_kron : forall n m (l : list (Matrix m n)) (A : Matrix m n),
                        (forall i, WF_Matrix (nth i l A)) ->
                         WF_Matrix (⨂ l).

Lemma WF_big_kron' : forall n m (l : list (Matrix m n)),
                        (forall A, In A l -> WF_Matrix A) ->
                         WF_Matrix (⨂ l).

Lemma WF_Mmult_n : forall n {m} (A : Square m),
   WF_Matrix A -> WF_Matrix (Mmult_n n A).

Lemma WF_direct_sum_n : forall n {m1 m2} (A : Matrix m1 m2),
   WF_Matrix A -> WF_Matrix (direct_sum_n n A).

Lemma WF_Msum : forall d1 d2 n (f : nat -> Matrix d1 d2),
  (forall i, (i < n)%nat -> WF_Matrix (f i)) ->
  WF_Matrix (big_sum f n).

Local Close Scope nat_scope.

(** * Tactics for showing well-formedness *)


Local Open Scope nat.
Local Open Scope R.
Local Open Scope C.

(* Proves [WF_Matrix] for a concretely defined matrix. It introduces an
   out-of-bounds index pair, rewrites the offending index with
   [le_plus_minus] (NOTE(review): presumably into bound + offset form),
   evaluates with [cbv], splits any remaining matches and tries [lca]. *)
Ltac show_wf :=
  unfold WF_Matrix;
  let x := fresh "x" in
  let y := fresh "y" in
  let H := fresh "H" in
  intros x y [H | H];
  apply le_plus_minus in H; rewrite H;
  cbv;
  destruct_m_eq;
  try lca.

(* Well-formedness lemmas used automatically by [auto with wf_db]. *)
#[export] Hint Resolve WF_Zero WF_I WF_I1 WF_mult WF_plus WF_scale WF_transpose
     WF_adjoint WF_outer_product WF_big_kron WF_kron_n WF_kron
     WF_Mmult_n WF_Msum : wf_db.
(* Dimension side conditions (e.g. from [WF_kron]) are tried with [unify_pows_two]. *)
#[export] Hint Extern 2 (_ = _) => unify_pows_two : wf_db.

(** * Basic matrix lemmas *)


(* A well-formed matrix with a zero dimension is [Zero]. *)
Lemma WF0_Zero_l : forall (n : nat) (A : Matrix 0%nat n), WF_Matrix A -> A = Zero.

Lemma WF0_Zero_r : forall (n : nat) (A : Matrix n 0%nat), WF_Matrix A -> A = Zero.

Lemma WF0_Zero : forall (A : Matrix 0%nat 0%nat), WF_Matrix A -> A = Zero.

Lemma I0_Zero : I 0 = Zero.

Lemma trace_plus_dist : forall (n : nat) (A B : Square n),
    trace (A .+ B) = (trace A + trace B)%C.

Lemma trace_mult_dist : forall n p (A : Square n), trace (p .* A) = (p * trace A)%C.

Lemma Mplus_0_l : forall (m n : nat) (A : Matrix m n), Zero .+ A = A.

Lemma Mplus_0_r : forall (m n : nat) (A : Matrix m n), A .+ Zero = A.

Lemma Mmult_0_l : forall (m n o : nat) (A : Matrix n o), @Zero m n × A = Zero.

Lemma Mmult_0_r : forall (m n o : nat) (A : Matrix m n), A × @Zero n o = Zero.

(* Partial sums of [I m × A] at entry [(x, z)]: zero before reaching the
   diagonal index [x], and equal to [A x z] afterwards. *)
Lemma Mmult_1_l_gen: forall (m n : nat) (A : Matrix m n) (x z k : nat),
  (k <= m)%nat ->
  ((k <= x)%nat -> big_sum (fun y : nat => I m x y * A y z) k = 0) /\
  ((k > x)%nat -> big_sum (fun y : nat => I m x y * A y z) k = A x z).

Lemma Mmult_1_l_mat_eq : forall (m n : nat) (A : Matrix m n), I m × A == A.

Lemma Mmult_1_l: forall (m n : nat) (A : Matrix m n),
  WF_Matrix A -> I m × A = A.

Lemma Mmult_1_r_gen: forall (m n : nat) (A : Matrix m n) (x z k : nat),
  (k <= n)%nat ->
  ((k <= z)%nat -> big_sum (fun y : nat => A x y * (I n) y z) k = 0) /\
  ((k > z)%nat -> big_sum (fun y : nat => A x y * (I n) y z) k = A x z).

Lemma Mmult_1_r_mat_eq : forall (m n : nat) (A : Matrix m n), A × I n == A.

Lemma Mmult_1_r: forall (m n : nat) (A : Matrix m n),
  WF_Matrix A -> A × I n = A.

(* The unbounded identity [I∞] also acts as a unit on well-formed matrices. *)
Lemma Mmult_inf_l : forall (m n : nat) (A : Matrix m n),
  WF_Matrix A -> I∞ × A = A.

Lemma Mmult_inf_r : forall (m n : nat) (A : Matrix m n),
  WF_Matrix A -> A × I∞ = A.

Lemma kron_0_l : forall (m n o p : nat) (A : Matrix o p),
  @Zero m n ⊗ A = Zero.

Lemma kron_0_r : forall (m n o p : nat) (A : Matrix m n),
   A ⊗ @Zero o p = Zero.

Lemma kron_1_r : forall (m n : nat) (A : Matrix m n), A ⊗ I 1 = A.

Lemma kron_1_l : forall (m n : nat) (A : Matrix m n),
  WF_Matrix A -> I 1 ⊗ A = A.

Theorem transpose_involutive : forall (m n : nat) (A : Matrix m n), (A⊤)⊤ = A.

Theorem adjoint_involutive : forall (m n : nat) (A : Matrix m n), A†† = A.

Lemma id_transpose_eq : forall n, (I n)⊤ = (I n).

Lemma zero_transpose_eq : forall m n, (@Zero m n)⊤ = @Zero m n.

Lemma id_adjoint_eq : forall n, (I n)† = (I n).

Lemma zero_adjoint_eq : forall m n, (@Zero m n)† = @Zero n m.

Theorem Mplus_comm : forall (m n : nat) (A B : Matrix m n), A .+ B = B .+ A.

Theorem Mplus_assoc : forall (m n : nat) (A B C : Matrix m n), A .+ B .+ C = A .+ (B .+ C).

Theorem Mmult_assoc : forall {m n o p : nat} (A : Matrix m n) (B : Matrix n o)
  (C: Matrix o p), A × B × C = A × (B × C).

Lemma Mmult_plus_distr_l : forall (m n o : nat) (A : Matrix m n) (B C : Matrix n o),
                           A × (B .+ C) = A × B .+ A × C.

Lemma Mmult_plus_distr_r : forall (m n o : nat) (A B : Matrix m n) (C : Matrix n o),
                           (A .+ B) × C = A × C .+ B × C.

Lemma kron_plus_distr_l : forall (m n o p : nat) (A : Matrix m n) (B C : Matrix o p),
                           A ⊗ (B .+ C) = A ⊗ B .+ A ⊗ C.

Lemma kron_plus_distr_r : forall (m n o p : nat) (A B : Matrix m n) (C : Matrix o p),
                           (A .+ B) ⊗ C = A ⊗ C .+ B ⊗ C.

Lemma Mscale_0_l : forall (m n : nat) (A : Matrix m n), C0 .* A = Zero.

Lemma Mscale_0_r : forall (m n : nat) (c : C), c .* @Zero m n = Zero.

Lemma Mscale_1_l : forall (m n : nat) (A : Matrix m n), C1 .* A = A.

(* Scaling the identity puts [c] on the in-bounds diagonal. *)
Lemma Mscale_1_r : forall (n : nat) (c : C),
    c .* I n = fun x y => if (x =? y) && (x <? n) then c else C0.

Lemma Mscale_assoc : forall (m n : nat) (x y : C) (A : Matrix m n),
  x .* (y .* A) = (x * y) .* A.

(* Scaling by a nonzero constant is injective. *)
Lemma Mscale_div : forall {n m} (c : C) (A B : Matrix n m),
  c <> C0 -> c .* A = c .* B -> A = B.

Lemma Mscale_plus_distr_l : forall (m n : nat) (x y : C) (A : Matrix m n),
  (x + y) .* A = x .* A .+ y .* A.

Lemma Mscale_plus_distr_r : forall (m n : nat) (x : C) (A B : Matrix m n),
  x .* (A .+ B) = x .* A .+ x .* B.

Lemma Mscale_mult_dist_l : forall (m n o : nat) (x : C) (A : Matrix m n) (B : Matrix n o),
    ((x .* A) × B) = x .* (A × B).

Lemma Mscale_mult_dist_r : forall (m n o : nat) (x : C) (A : Matrix m n) (B : Matrix n o),
    (A × (x .* B)) = x .* (A × B).

Lemma Mscale_kron_dist_l : forall (m n o p : nat) (x : C) (A : Matrix m n) (B : Matrix o p),
    ((x .* A) ⊗ B) = x .* (A ⊗ B).

Lemma Mscale_kron_dist_r : forall (m n o p : nat) (x : C) (A : Matrix m n) (B : Matrix o p),
    (A ⊗ (x .* B)) = x .* (A ⊗ B).

Lemma Mscale_trans : forall (m n : nat) (x : C) (A : Matrix m n),
    (x .* A)⊤ = x .* A⊤.

Lemma Mscale_adj : forall (m n : nat) (x : C) (A : Matrix m n),
    (x .* A)† = x^* .* A†.

Lemma Mplus_transpose : forall (m n : nat) (A : Matrix m n) (B : Matrix m n),
  (A .+ B)⊤ = A⊤ .+ B⊤.

Lemma Mmult_transpose : forall (m n o : nat) (A : Matrix m n) (B : Matrix n o),
      (A × B)⊤ = B⊤ × A⊤.

Lemma kron_transpose : forall (m n o p : nat) (A : Matrix m n) (B : Matrix o p ),
  (A ⊗ B)⊤ = A⊤ ⊗ B⊤.

Lemma Mplus_adjoint : forall (m n : nat) (A : Matrix m n) (B : Matrix m n),
  (A .+ B)† = A† .+ B†.

Lemma Mmult_adjoint : forall {m n o : nat} (A : Matrix m n) (B : Matrix n o),
      (A × B)† = B† × A†.

Lemma kron_adjoint : forall {m n o p : nat} (A : Matrix m n) (B : Matrix o p),
  (A ⊗ B)† = A† ⊗ B†.

Lemma id_kron : forall (m n : nat), I m ⊗ I n = I (m * n).

Local Open Scope nat_scope.

(* Index arithmetic used for reasoning about [kron]. *)
Lemma div_mod : forall (x y z : nat), (x / y) mod z = (x mod (y * z)) / y.

Lemma sub_mul_mod :
  forall x y z,
    y * z <= x ->
    (x - y * z) mod z = x mod z.

Lemma mod_product : forall x y z, y <> 0 -> x mod (y * z) mod z = x mod z.

(* Associativity of [kron] holds as [mat_equiv] for arbitrary matrices,
   and as equality for well-formed ones. *)
Lemma kron_assoc_mat_equiv : forall {m n p q r s : nat}
  (A : Matrix m n) (B : Matrix p q) (C : Matrix r s),
  (A ⊗ B ⊗ C) == A ⊗ (B ⊗ C).

Lemma kron_assoc : forall {m n p q r s : nat}
  (A : Matrix m n) (B : Matrix p q) (C : Matrix r s),
  WF_Matrix A -> WF_Matrix B -> WF_Matrix C ->
  (A ⊗ B ⊗ C) = A ⊗ (B ⊗ C).

Lemma kron_mixed_product : forall {m n o p q r : nat} (A : Matrix m n) (B : Matrix p q)
  (C : Matrix n o) (D : Matrix q r), (A ⊗ B) × (C ⊗ D) = (A × C) ⊗ (B × D).

(* Variant of [kron_mixed_product] with all dimensions explicit and related
   by equations. *)
Lemma kron_mixed_product' : forall (m n n' o p q q' r mp nq or: nat)
    (A : Matrix m n) (B : Matrix p q) (C : Matrix n' o) (D : Matrix q' r),
    n = n' -> q = q' ->
    mp = m * p -> nq = n * q -> or = o * r ->
  (@Mmult mp nq or (@kron m n p q A B) (@kron n' o q' r C D)) =
  (@kron m o p r (@Mmult m n o A C) (@Mmult p q r B D)).

Lemma direct_sum_assoc : forall {m n p q r s : nat}
  (A : Matrix m n) (B : Matrix p q) (C : Matrix r s),
  (A .⊕ B .⊕ C) = A .⊕ (B .⊕ C).

Lemma outer_product_eq : forall m (φ ψ : Matrix m 1),
 φ = ψ -> outer_product φ φ = outer_product ψ ψ.

Lemma outer_product_kron : forall m n (φ : Matrix m 1) (ψ : Matrix n 1),
    outer_product φ φ ⊗ outer_product ψ ψ = outer_product (φ ⊗ ψ) (φ ⊗ ψ).

Lemma big_kron_app : forall {n m} (l1 l2 : list (Matrix n m)),
  (forall i, WF_Matrix (nth i l1 (@Zero n m))) ->
  (forall i, WF_Matrix (nth i l2 (@Zero n m))) ->
  ⨂ (l1 ++ l2) = (⨂ l1) ⊗ (⨂ l2).

(* [kron_n] appends copies on the right; this moves the new copy to the left. *)
Lemma kron_n_assoc :
  forall n {m1 m2} (A : Matrix m1 m2), WF_Matrix A -> (S n) ⨂ A = A ⊗ (n ⨂ A).

Lemma kron_n_adjoint : forall n {m1 m2} (A : Matrix m1 m2),
  WF_Matrix A -> (n ⨂ A)† = n ⨂ A†.

Lemma kron_n_transpose : forall (m n o : nat) (A : Matrix m n),
  (o ⨂ A)⊤ = o ⨂ (A⊤).

Lemma Mscale_kron_n_distr_r : forall {m1 m2} n α (A : Matrix m1 m2),
  n ⨂ (α .* A) = (α ^ n) .* (n ⨂ A).

Lemma kron_n_mult : forall {m1 m2 m3} n (A : Matrix m1 m2) (B : Matrix m2 m3),
  n ⨂ A × n ⨂ B = n ⨂ (A × B).

Lemma kron_n_I : forall n, n ⨂ I 2 = I (2 ^ n).

Lemma Mmult_n_kron_distr_l : forall {m n} i (A : Square m) (B : Square n),
  i ⨉ (A ⊗ B) = (i ⨉ A) ⊗ (i ⨉ B).

Lemma Mmult_n_1_l : forall {n} (A : Square n),
  WF_Matrix A ->
  1 ⨉ A = A.

Lemma Mmult_n_1_r : forall n i,
  i ⨉ (I n) = I n.

(* If [ψ] is an eigenvector of [A] with eigenvalue [λ], it is one of [i ⨉ A]
   with eigenvalue [λ ^ i]. *)
Lemma Mmult_n_eigenvector : forall {n} (A : Square n) (ψ : Vector n) λ i,
  WF_Matrix ψ -> A × ψ = λ .* ψ ->
  i ⨉ A × ψ = (λ ^ i) .* ψ.

(** * Summation lemmas specific to matrices *)


(* [kron], [Mmult] and [scale] distribute over [big_sum] of matrices. *)
Lemma kron_Msum_distr_l :
  forall {d1 d2 d3 d4} n (f : nat -> Matrix d1 d2) (A : Matrix d3 d4),
  A ⊗ big_sum f n = big_sum (fun i => A ⊗ f i) n.

Lemma kron_Msum_distr_r :
  forall {d1 d2 d3 d4} n (f : nat -> Matrix d1 d2) (A : Matrix d3 d4),
  big_sum f n ⊗ A = big_sum (fun i => f i ⊗ A) n.

Lemma Mmult_Msum_distr_l : forall {d1 d2 m} n (f : nat -> Matrix d1 d2) (A : Matrix m d1),
  A × big_sum f n = big_sum (fun i => A × f i) n.

Lemma Mmult_Msum_distr_r : forall {d1 d2 m} n (f : nat -> Matrix d1 d2) (A : Matrix d2 m),
  big_sum f n × A = big_sum (fun i => f i × A) n.

Lemma Mscale_Msum_distr_r : forall {d1 d2} n (c : C) (f : nat -> Matrix d1 d2),
  big_sum (fun i => c .* (f i)) n = c .* big_sum f n.

Lemma Mscale_Msum_distr_l : forall {d1 d2} n (f : nat -> C) (A : Matrix d1 d2),
  big_sum (fun i => (f i) .* A) n = big_sum f n .* A.

(* Summing [n] copies of [A] scales it by [INR n]. *)
Lemma Msum_constant : forall {d1 d2} n (A : Matrix d1 d2), big_sum (fun _ => A) n = INR n .* A.

Lemma Msum_adjoint : forall {d1 d2} n (f : nat -> Matrix d1 d2),
  (big_sum f n)† = big_sum (fun i => (f i)†) n.

(* An entry of a matrix sum is the complex sum of the entries. *)
Lemma Msum_Csum : forall {d1 d2} n (f : nat -> Matrix d1 d2) i j,
  (big_sum f n) i j = big_sum (fun x => (f x) i j) n.

Lemma Msum_plus : forall n {d1 d2} (f g : nat -> Matrix d1 d2),
    big_sum (fun x => f x .+ g x) n = big_sum f n .+ big_sum g n.

(** * Defining matrix-altering row and column operations *)


(** Column [i] of [S], as a column vector. *)
Definition get_vec {n m} (i : nat) (S : Matrix n m) : Vector n :=
  fun x y => (if (y =? 0) then S x i else C0).

(** Row [i] of [S], as a row vector. *)
Definition get_row {n m} (i : nat) (S : Matrix n m) : Matrix 1 m :=
  fun x y => (if (x =? 0) then S i y else C0).

(** Deletes row [row]; the rows after it move up by one. *)
Definition reduce_row {n m} (A : Matrix (S n) m) (row : nat) : Matrix n m :=
  fun x y => if x <? row
             then A x y
             else A (1 + x) y.

(** Deletes column [col]; the columns after it move left by one. *)
Definition reduce_col {n m} (A : Matrix n (S m)) (col : nat) : Matrix n m :=
  fun x y => if y <? col
             then A x y
             else A x (1 + y).

(** Deletes entry [n] of a vector of length [S n]. *)
Definition reduce_vecn {n} (v : Vector (S n)) : Vector n :=
  fun x y => if x <? n
             then v x y
             else v (1 + x) y.

(** Deletes row [row] and column [col] of a square matrix. *)
Definition reduce {n} (A : Square (S n)) (row col : nat) : Square n :=
  fun x y => (if x <? row
              then (if y <? col
                    then A x y
                    else A x (1+y))
              else (if y <? col
                    then A (1+x) y
                    else A (1+x) (1+y))).

(** Adds [v] as a new column at index [m]. *)
Definition col_append {n m} (T : Matrix n m) (v : Vector n) : Matrix n (S m) :=
  fun i j => if (j =? m) then v i 0 else T i j.

(** Adds [v] as a new row at index [n]. *)
Definition row_append {n m} (T : Matrix n m) (v : Matrix 1 m) : Matrix (S n) m :=
  fun i j => if (i =? n) then v 0 j else T i j.

(** Places the columns of [T2] after the [m1] columns of [T1]. *)
Definition smash {n m1 m2} (T1 : Matrix n m1) (T2 : Matrix n m2) : Matrix n (m1 + m2) :=
  fun i j => if j <? m1 then T1 i j else T2 i (j - m1).

(** Inserts [v] as column [spot]; later columns move right by one. *)
Definition col_wedge {n m} (T : Matrix n m) (v : Vector n) (spot : nat) : Matrix n (S m) :=
  fun i j => if j <? spot
             then T i j
             else if j =? spot
                  then v i 0
                  else T i (j-1).

(** Inserts [v] as row [spot]; later rows move down by one. *)
Definition row_wedge {n m} (T : Matrix n m) (v : Matrix 1 m) (spot : nat) : Matrix (S n) m :=
  fun i j => if i <? spot
             then T i j
             else if i =? spot
                  then v 0 j
                  else T (i-1) j.

(** Exchanges columns [x] and [y]. *)
Definition col_swap {n m : nat} (S : Matrix n m) (x y : nat) : Matrix n m :=
  fun i j => if (j =? x)
             then S i y
             else if (j =? y)
                  then S i x
                  else S i j.

(** Exchanges rows [x] and [y]. *)
Definition row_swap {n m : nat} (S : Matrix n m) (x y : nat) : Matrix n m :=
  fun i j => if (i =? x)
             then S y j
             else if (i =? y)
                  then S x j
                  else S i j.

(** Multiplies column [col] by [a]. *)
Definition col_scale {n m : nat} (S : Matrix n m) (col : nat) (a : C) : Matrix n m :=
  fun i j => if (j =? col)
             then (a * S i j)%C
             else S i j.

(** Multiplies row [row] by [a]. *)
Definition row_scale {n m : nat} (S : Matrix n m) (row : nat) (a : C) : Matrix n m :=
  fun i j => if (i =? row)
             then (a * S i j)%C
             else S i j.

(** Adds [a] times column [to_add] to column [col]. *)
Definition col_add {n m : nat} (S : Matrix n m) (col to_add : nat) (a : C) : Matrix n m :=
  fun i j => if (j =? col)
             then (S i j + a * S i to_add)%C
             else S i j.

(** Adds [a] times row [to_add] to row [row]. *)
Definition row_add {n m : nat} (S : Matrix n m) (row to_add : nat) (a : C) : Matrix n m :=
  fun i j => if (i =? row)
             then (S i j + a * S to_add j)%C
             else S i j.

(** Linear combination of the [m] columns of [S], with coefficient [as' i 0]
    for column [i]. *)
Definition gen_new_vec (n m : nat) (S : Matrix n m) (as' : Vector m) : Vector n :=
  big_sum (fun i => (as' i 0) .* (get_vec i S)) m.

(** Linear combination of the [n] rows of [S], with coefficient [as' 0 i]
    for row [i]. *)
Definition gen_new_row (n m : nat) (S : Matrix n m) (as' : Matrix 1 n) : Matrix 1 m :=
  big_sum (fun i => (as' 0 i) .* (get_row i S)) n.

(** Adds [gen_new_vec n m S as'] to column [col]. *)
Definition col_add_many {n m} (col : nat) (as' : Vector m) (S : Matrix n m) : Matrix n m :=
  fun i j => if (j =? col)
             then (S i j + (gen_new_vec n m S as') i 0)%C
             else S i j.

(** Adds [gen_new_row n m S as'] to row [row]. *)
Definition row_add_many {n m} (row : nat) (as' : Matrix 1 n) (S : Matrix n m) : Matrix n m :=
  fun i j => if (i =? row)
             then (S i j + (gen_new_row n m S as') 0 j)%C
             else S i j.

(** Every column [j] gains [as' 0 j] times column [col]. *)
Definition col_add_each {n m} (col : nat) (as' : Matrix 1 m) (S : Matrix n m) : Matrix n m :=
  S .+ ((get_vec col S) × as').

(** Every row [i] gains [as' i 0] times row [row]. *)
Definition row_add_each {n m} (row : nat) (as' : Vector n) (S : Matrix n m) : Matrix n m :=
  S .+ (as' × get_row row S).

(** Sets every entry of column [col] to [C0]. *)
Definition make_col_zero {n m} (col : nat) (S : Matrix n m) : Matrix n m :=
  fun i j => if (j =? col)
             then C0
             else S i j.

(** Sets every entry of row [row] to [C0]. *)
Definition make_row_zero {n m} (row : nat) (S : Matrix n m) : Matrix n m :=
  fun i j => if (i =? row)
             then C0
             else S i j.

(** Sets every entry outside the declared [n] x [m] bounds to [C0]. *)
Definition make_WF {n m} (S : Matrix n m) : Matrix n m :=
  fun i j => if (i <? n) && (j <? m) then S i j else C0.

(** * Proving lemmas about these new functions *)

(* Well-formedness of the row/column operations. *)

Lemma WF_get_vec : forall {n m} (i : nat) (S : Matrix n m),
  WF_Matrix S -> WF_Matrix (get_vec i S).

Lemma WF_get_row : forall {n m} (i : nat) (S : Matrix n m),
  WF_Matrix S -> WF_Matrix (get_row i S).

Lemma WF_reduce_row : forall {n m} (row : nat) (A : Matrix (S n) m),
  row < (S n) -> WF_Matrix A -> WF_Matrix (reduce_row A row).

Lemma WF_reduce_col : forall {n m} (col : nat) (A : Matrix n (S m)),
  col < (S m) -> WF_Matrix A -> WF_Matrix (reduce_col A col).

Lemma rvn_is_rr_n : forall {n : nat} (v : Vector (S n)),
  reduce_vecn v = reduce_row v n.

Lemma WF_reduce_vecn : forall {n} (v : Vector (S n)),
  n <> 0 -> WF_Matrix v -> WF_Matrix (reduce_vecn v).

Lemma reduce_is_redrow_redcol : forall {n} (A : Square (S n)) (row col : nat),
  reduce A row col = reduce_col (reduce_row A row) col.

Lemma reduce_is_redcol_redrow : forall {n} (A : Square (S n)) (row col : nat),
  reduce A row col = reduce_row (reduce_col A col) row.

Lemma WF_reduce : forall {n} (A : Square (S n)) (row col : nat),
  row < S n -> col < S n -> WF_Matrix A -> WF_Matrix (reduce A row col).

Lemma WF_col_swap : forall {n m : nat} (S : Matrix n m) (x y : nat),
  x < m -> y < m -> WF_Matrix S -> WF_Matrix (col_swap S x y).

Lemma WF_row_swap : forall {n m : nat} (S : Matrix n m) (x y : nat),
  x < n -> y < n -> WF_Matrix S -> WF_Matrix (row_swap S x y).

Lemma WF_col_scale : forall {n m : nat} (S : Matrix n m) (x : nat) (a : C),
  WF_Matrix S -> WF_Matrix (col_scale S x a).

Lemma WF_row_scale : forall {n m : nat} (S : Matrix n m) (x : nat) (a : C),
  WF_Matrix S -> WF_Matrix (row_scale S x a).

Lemma WF_col_add : forall {n m : nat} (S : Matrix n m) (x y : nat) (a : C),
  x < m -> WF_Matrix S -> WF_Matrix (col_add S x y a).

Lemma WF_row_add : forall {n m : nat} (S : Matrix n m) (x y : nat) (a : C),
  x < n -> WF_Matrix S -> WF_Matrix (row_add S x y a).

Lemma WF_gen_new_vec : forall {n m} (S : Matrix n m) (as' : Vector m),
  WF_Matrix S -> WF_Matrix (gen_new_vec n m S as').

Lemma WF_gen_new_row : forall {n m} (S : Matrix n m) (as' : Matrix 1 n),
  WF_Matrix S -> WF_Matrix (gen_new_row n m S as').

Lemma WF_col_add_many : forall {n m} (col : nat) (as' : Vector m) (S : Matrix n m),
  col < m -> WF_Matrix S -> WF_Matrix (col_add_many col as' S).

Lemma WF_row_add_many : forall {n m} (row : nat) (as' : Matrix 1 n) (S : Matrix n m),
  row < n -> WF_Matrix S -> WF_Matrix (row_add_many row as' S).

Lemma WF_col_append : forall {n m} (T : Matrix n m) (v : Vector n),
  WF_Matrix T -> WF_Matrix v -> WF_Matrix (col_append T v).

Lemma WF_row_append : forall {n m} (T : Matrix n m) (v : Matrix 1 m),
  WF_Matrix T -> WF_Matrix v -> WF_Matrix (row_append T v).

Lemma WF_col_wedge : forall {n m} (T : Matrix n m) (v : Vector n) (spot : nat),
  spot <= m -> WF_Matrix T -> WF_Matrix v -> WF_Matrix (col_wedge T v spot).

Lemma WF_row_wedge : forall {n m} (T : Matrix n m) (v : Matrix 1 m) (spot : nat),
  spot <= n -> WF_Matrix T -> WF_Matrix v -> WF_Matrix (row_wedge T v spot).

Lemma WF_smash : forall {n m1 m2} (T1 : Matrix n m1) (T2 : Matrix n m2),
  WF_Matrix T1 -> WF_Matrix T2 -> WF_Matrix (smash T1 T2).

Lemma WF_col_add_each : forall {n m} (col : nat) (as' : Matrix 1 m) (S : Matrix n m),
  WF_Matrix S -> WF_Matrix as' -> WF_Matrix (col_add_each col as' S).

Lemma WF_row_add_each : forall {n m} (row : nat) (as' : Vector n) (S : Matrix n m),
  WF_Matrix S -> WF_Matrix as' -> WF_Matrix (row_add_each row as' S).

Lemma WF_make_col_zero : forall {n m} (col : nat) (S : Matrix n m),
  WF_Matrix S -> WF_Matrix (make_col_zero col S).

Lemma WF_make_row_zero : forall {n m} (row : nat) (S : Matrix n m),
  WF_Matrix S -> WF_Matrix (make_row_zero row S).

Lemma WF_make_WF : forall {n m} (S : Matrix n m), WF_Matrix (make_WF S).

#[export] Hint Resolve WF_get_vec WF_get_row WF_reduce_row WF_reduce_col WF_reduce_vecn WF_reduce : wf_db.
#[export] Hint Resolve WF_col_swap WF_row_swap WF_col_scale WF_row_scale WF_col_add WF_row_add : wf_db.
#[export] Hint Resolve WF_gen_new_vec WF_gen_new_row WF_col_add_many WF_row_add_many : wf_db.
#[export] Hint Resolve WF_col_append WF_row_append WF_row_wedge WF_col_wedge WF_smash : wf_db.
#[export] Hint Resolve WF_col_add_each WF_row_add_each WF_make_col_zero WF_make_row_zero WF_make_WF : wf_db.
(* Index-bound side conditions ([_ < _]) raised by the lemmas above are sent to [lia]. *)
#[export] Hint Extern 1 (Nat.lt _ _) => lia : wf_db.

Lemma get_vec_reduce_col : forall {n m} (i col : nat) (A : Matrix n (S m)),
  i < col -> get_vec i (reduce_col A col) = get_vec i A.

Lemma get_vec_conv : forall {n m} (x y : nat) (S : Matrix n m),
  (get_vec y S) x 0 = S x y.

Lemma get_vec_mult : forall {n} (i : nat) (A B : Square n),
  A × (get_vec i B) = get_vec i (A × B).

(* Square matrices with the same columns are equal. *)
Lemma det_by_get_vec : forall {n} (A B : Square n),
  (forall i, get_vec i A = get_vec i B) -> A = B.

Lemma col_scale_reduce_col_same : forall {n m} (T : Matrix n (S m)) (y col : nat) (a : C),
  y = col -> reduce_col (col_scale T col a) y = reduce_col T y.

Lemma col_swap_reduce_before : forall {n : nat} (T : Square (S n)) (row col c1 c2 : nat),
  col < (S c1) -> col < (S c2) ->
  reduce (col_swap T (S c1) (S c2)) row col = col_swap (reduce T row col) c1 c2.

Lemma col_scale_reduce_before : forall {n : nat} (T : Square (S n)) (x y col : nat) (a : C),
  y < col -> reduce (col_scale T col a) x y = col_scale (reduce T x y) (col - 1) a.

Lemma col_scale_reduce_same : forall {n : nat} (T : Square (S n)) (x y col : nat) (a : C),
  y = col -> reduce (col_scale T col a) x y = reduce T x y.

Lemma col_scale_reduce_after : forall {n : nat} (T : Square (S n)) (x y col : nat) (a : C),
  y > col -> reduce (col_scale T col a) x y = col_scale (reduce T x y) col a.

Lemma mcz_reduce_col_same : forall {n m} (T : Matrix n (S m)) (col : nat),
  reduce_col (make_col_zero col T) col = reduce_col T col.

Lemma mrz_reduce_row_same : forall {n m} (T : Matrix (S n) m) (row : nat),
  reduce_row (make_row_zero row T) row = reduce_row T row.

Lemma col_add_many_reduce_col_same : forall {n m} (T : Matrix n (S m)) (v : Vector (S m))
                                            (col : nat),
  reduce_col (col_add_many col v T) col = reduce_col T col.

Lemma row_add_many_reduce_row_same : forall {n m} (T : Matrix (S n) m) (v : Matrix 1 (S n))
                                            (row : nat),
  reduce_row (row_add_many row v T) row = reduce_row T row.

(* [v] is a column of length [n], matching [col_wedge]'s signature (it was
   previously declared [Vector m], which type-checked only because matrix
   dimensions are phantom). *)
Lemma col_wedge_reduce_col_same : forall {n m} (T : Matrix n m) (v : Vector n)
                                         (col : nat),
  reduce_col (col_wedge T v col) col = T.

(* [v] is a row of length [m], matching [row_wedge]'s signature (previously
   declared [Matrix 1 n]). *)
Lemma row_wedge_reduce_row_same : forall {n m} (T : Matrix n m) (v : Matrix 1 m)
                                         (row : nat),
  reduce_row (row_wedge T v row) row = T.

Lemma col_add_many_reduce_row : forall {n m} (T : Matrix (S n) m) (v : Vector m) (col row : nat),
  col_add_many col v (reduce_row T row) = reduce_row (col_add_many col v T) row.

Lemma col_swap_same : forall {n m : nat} (S : Matrix n m) (x : nat),
  col_swap S x x = S.

Lemma row_swap_same : forall {n m : nat} (S : Matrix n m) (x : nat),
  row_swap S x x = S.

Lemma col_swap_diff_order : forall {n m : nat} (S : Matrix n m) (x y : nat),
  col_swap S x y = col_swap S y x.

Lemma row_swap_diff_order : forall {n m : nat} (S : Matrix n m) (x y : nat),
  row_swap S x y = row_swap S y x.

Lemma col_swap_inv : forall {n m : nat} (S : Matrix n m) (x y : nat),
  S = col_swap (col_swap S x y) x y.

Lemma row_swap_inv : forall {n m : nat} (S : Matrix n m) (x y : nat),
  S = row_swap (row_swap S x y) x y.

Lemma col_swap_get_vec : forall {n m : nat} (S : Matrix n m) (x y : nat),
  get_vec y S = get_vec x (col_swap S x y).

Lemma col_swap_three : forall {n m} (T : Matrix n m) (x y z : nat),
  x <> z -> y <> z -> col_swap T x z = col_swap (col_swap (col_swap T x y) y z) x y.

Lemma reduce_row_reduce_col : forall {n m} (A : Matrix (S n) (S m)) (i j : nat),
  reduce_col (reduce_row A i) j = reduce_row (reduce_col A j) i.

Lemma reduce_col_swap_01 : forall {n} (A : Square (S (S n))),
  reduce_col (reduce_col (col_swap A 0 1) 0) 0 = reduce_col (reduce_col A 0) 0.

Lemma reduce_reduce_0 : {n} (A : Square (S (S n))) (x y : nat),
  x y
  (reduce (reduce A x 0) y 0) = (reduce (reduce A (S y) 0) x 0).

Lemma col_add_split : {n} (A : Square (S n)) (i : nat) (c : C),
  col_add A 0 i c = col_wedge (reduce_col A 0) (get_vec 0 A .+ c.* get_vec i A) 0.

Lemma col_swap_col_add_Si : {n} (A : Square n) (i j : nat) (c : C),
  i 0 i j col_swap (col_add (col_swap A j 0) 0 i c) j 0 = col_add A j i c.

Lemma col_swap_col_add_0 : {n} (A : Square n) (j : nat) (c : C),
  j 0 col_swap (col_add (col_swap A j 0) 0 j c) j 0 = col_add A j 0 c.

Lemma col_swap_end_reduce_col_hit : {n m : nat} (T : Matrix n (S (S m))) (i : nat),
  i m col_swap (reduce_col T i) m i = reduce_col (col_swap T (S m) (S i)) i.

Lemma col_swap_reduce_row : {n m : nat} (S : Matrix (S n) m) (x y row : nat),
  col_swap (reduce_row S row) x y = reduce_row (col_swap S x y) row.

(** For a nonzero [a] (side condition presumably [a ≠ C0]), scaling column
    [x] by [a] and then by [/ a] gives back [S]. *)
Lemma col_scale_inv : {n m : nat} (S : Matrix n m) (x : nat) (a : C),
  a C0 S = col_scale (col_scale S x a) x (/ a).

(** Row analogue of [col_scale_inv]. *)
Lemma row_scale_inv : {n m : nat} (S : Matrix n m) (x : nat) (a : C),
  a C0 S = row_scale (row_scale S x a) x (/ a).

(** Adding [a] times column [x] to itself scales that column by [C1 + a]. *)
Lemma col_add_double : {n m : nat} (S : Matrix n m) (x : nat) (a : C),
  col_add S x x a = col_scale S x (C1 + a).

(** Row analogue of [col_add_double]. *)
Lemma row_add_double : {n m : nat} (S : Matrix n m) (x : nat) (a : C),
  row_add S x x a = row_scale S x (C1 + a).

(** Swapping columns [x] and [y] after adding [a] times column [y] into
    column [x] equals swapping first and then adding [a] times column [x]
    into column [y]. *)
Lemma col_add_swap : {n m : nat} (S : Matrix n m) (x y : nat) (a : C),
  col_swap (col_add S x y a) x y = col_add (col_swap S x y) y x a.

(** Row analogue of [col_add_swap]. *)
Lemma row_add_swap : {n m : nat} (S : Matrix n m) (x y : nat) (a : C),
  row_swap (row_add S x y a) x y = row_add (row_swap S x y) y x a.

(** For distinct [x] and [y] (side condition presumably [x ≠ y]), adding
    [a] times column [y] into column [x] is undone by adding [-a] times
    column [y]. *)
Lemma col_add_inv : {n m : nat} (S : Matrix n m) (x y : nat) (a : C),
  x y S = col_add (col_add S x y a) x y (-a).

(** Row analogue of [col_add_inv]. *)
Lemma row_add_inv : {n m : nat} (S : Matrix n m) (x y : nat) (a : C),
  x y S = row_add (row_add S x y a) x y (-a).

(** [make_WF T] agrees with [T] on every in-bounds entry. *)
Lemma mat_equiv_make_WF : {n m} (T : Matrix n m),
  T == make_WF T.

(** A well-formed matrix is left unchanged by [make_WF]. *)
Lemma eq_make_WF : {n m} (T : Matrix n m),
  WF_Matrix T T = make_WF T.

(** For in-bounds column indices, [col_swap] commutes with [make_WF]. *)
Lemma col_swap_make_WF : {n m} (T : Matrix n m) (x y : nat),
  x < m y < m col_swap (make_WF T) x y = make_WF (col_swap T x y).

(** [col_scale] commutes with [make_WF]; no bound on [x] is required. *)
Lemma col_scale_make_WF : {n m} (T : Matrix n m) (x : nat) (c : C),
  col_scale (make_WF T) x c = make_WF (col_scale T x c).

(** For in-bounds column indices, [col_add] commutes with [make_WF]. *)
Lemma col_add_make_WF : {n m} (T : Matrix n m) (x y : nat) (c : C),
  x < m y < m col_add (make_WF T) x y c = make_WF (col_add T x y c).

(** Multiplying the [make_WF] versions of two matrices gives [make_WF] of
    their product. *)
Lemma Mmult_make_WF : {n m o} (A : Matrix n m) (B : Matrix m o),
  make_WF A × make_WF B = make_WF (A × B).

(** A coefficient vector equal to [Zero] on its in-bounds entries ([==])
    makes [gen_new_vec] produce [Zero]. Presumably [gen_new_vec] forms the
    linear combination of [T]'s columns with coefficients [as']. *)
Lemma gen_new_vec_0 : {n m} (T : Matrix n m) (as' : Vector m),
  as' == Zero gen_new_vec n m T as' = Zero.

(** Row analogue of [gen_new_vec_0]. *)
Lemma gen_new_row_0 : {n m} (T : Matrix n m) (as' : Matrix 1 n),
  as' == Zero gen_new_row n m T as' = Zero.

(** With an (in-bounds) zero coefficient vector, [col_add_many col] leaves
    [T] unchanged. *)
Lemma col_add_many_0 : {n m} (col : nat) (T : Matrix n m) (as' : Vector m),
  as' == Zero T = col_add_many col as' T.

(** Row analogue of [col_add_many_0]. *)
Lemma row_add_many_0 : {n m} (row : nat) (T : Matrix n m) (as' : Matrix 1 n),
  as' == Zero T = row_add_many row as' T.

(** [gen_new_vec] depends only on the in-bounds entries of its coefficient
    vector. *)
Lemma gen_new_vec_mat_equiv : {n m} (T : Matrix n m) (as' bs : Vector m),
  as' == bs gen_new_vec n m T as' = gen_new_vec n m T bs.

(** Row analogue of [gen_new_vec_mat_equiv]. *)
Lemma gen_new_row_mat_equiv : {n m} (T : Matrix n m) (as' bs : Matrix 1 n),
  as' == bs gen_new_row n m T as' = gen_new_row n m T bs.

(** [col_add_many] depends only on the in-bounds entries of its
    coefficient vector. *)
Lemma col_add_many_mat_equiv : {n m} (col : nat) (T : Matrix n m) (as' bs : Vector m),
  as' == bs col_add_many col as' T = col_add_many col bs T.

(** Row analogue of [col_add_many_mat_equiv]. *)
Lemma row_add_many_mat_equiv : {n m} (row : nat) (T : Matrix n m) (as' bs : Matrix 1 n),
  as' == bs row_add_many row as' T = row_add_many row bs T.

(** With a coefficient row that is literally [Zero] (equality, not [==]),
    [col_add_each col] leaves [T] unchanged. *)
Lemma col_add_each_0 : {n m} (col : nat) (T : Matrix n m) (v : Matrix 1 m),
  v = Zero T = col_add_each col v T.

(** Row analogue of [col_add_each_0]. *)
Lemma row_add_each_0 : {n m} (row : nat) (T : Matrix n m) (v : Vector n),
  v = Zero T = row_add_each row v T.

(** Peel one term off [col_add_many]. Under the side conditions on [col]
    and [e] (presumably [col ≠ e]), with [e < m] and [as' col 0 = C0], the
    result equals doing the same operation with coefficient [e] zeroed out
    ([make_row_zero e as']), followed by adding [as' e 0] times column [e]
    into column [col]. *)
Lemma col_add_many_col_add : {n m} (col e : nat) (T : Matrix n m) (as' : Vector m),
  col e e < m as' col 0 = C0
  col_add_many col as' T =
  col_add (col_add_many col (make_row_zero e as') T) col e (as' e 0).

(** Suppose [col < S m] and [as' col 0 = C0], and the other columns of [T],
    weighted by the other coefficients, sum to [-C1 .* get_vec col T]. Then
    [col_add_many col as' T] has column [col] entirely zero. *)
Lemma col_add_many_cancel : {n m} (T : Matrix n (S m)) (as' : Vector (S m)) (col : nat),
  col < (S m) as' col 0 = C0
  (reduce_col T col) × (reduce_row as' col) = -C1 .* (get_vec col T)
  ( i : nat, (col_add_many col as' T) i col = C0).

(** With [as' col 0 = C0], [col_add_many col as'] is undone by
    [col_add_many col (-C1 .* as')]. *)
Lemma col_add_many_inv : {n m} (S : Matrix n m) (col : nat) (as' : Vector m),
  as' col 0 = C0 S = col_add_many col (-C1 .* as') (col_add_many col as' S).

(** Peel one term off [col_add_each]. For [col ≠ e] (presumably) and [as']
    zero in column [col], the operation equals doing it with entry [e]
    zeroed ([make_col_zero e as']) followed by adding [as' 0 e] times
    column [col] into column [e]. *)
Lemma col_add_each_col_add : {n m} (col e : nat) (S : Matrix n m) (as' : Matrix 1 m),
  col e ( x, as' x col = C0)
              col_add_each col as' S =
              col_add (col_add_each col (make_col_zero e as') S) e col (as' 0 e).

(** Row analogue of [col_add_each_col_add]. *)
Lemma row_add_each_row_add : {n m} (row e : nat) (S : Matrix n m) (as' : Vector n),
  row e ( y, as' row y = C0)
              row_add_each row as' S =
              row_add (row_add_each row (make_row_zero e as') S) e row (as' e 0).

(** [col_add_each col] with the [col] entry of the coefficients zeroed is
    undone by the same operation with negated coefficients. *)
Lemma col_add_each_inv : {n m} (col : nat) (as' : Matrix 1 m) (T : Matrix n m),
  T = col_add_each col (make_col_zero col (-C1 .* as'))
                   (col_add_each col (make_col_zero col as') T).

(** Row analogue of [col_add_each_inv]. *)
Lemma row_add_each_inv : {n m} (row : nat) (as' : Vector n) (T : Matrix n m),
  T = row_add_each row (make_row_zero row (-C1 .* as'))
                   (row_add_each row (make_row_zero row as') T).

(** Transposition exchanges each column operation with the corresponding
    row operation, and vice versa. On each right-hand side the operation
    acts on the transpose of [A]. *)

(** The transposed column [i] of [A] is row [i] of the transpose. *)
Lemma get_vec_transpose : {n m} (A : Matrix n m) (i : nat),
  (get_vec i A)⊤ = get_row i (A).

(** The transposed row [i] of [A] is column [i] of the transpose. *)
Lemma get_row_transpose : {n m} (A : Matrix n m) (i : nat),
  (get_row i A)⊤ = get_vec i (A).

(** Transpose turns a column swap into a row swap. *)
Lemma col_swap_transpose : {n m} (A : Matrix n m) (x y : nat),
  (col_swap A x y)⊤ = row_swap (A) x y.

(** Transpose turns a row swap into a column swap. *)
Lemma row_swap_transpose : {n m} (A : Matrix n m) (x y : nat),
  (row_swap A x y)⊤ = col_swap (A) x y.

(** Transpose turns a column scaling into a row scaling. *)
Lemma col_scale_transpose : {n m} (A : Matrix n m) (x : nat) (a : C),
  (col_scale A x a)⊤ = row_scale (A) x a.

(** Transpose turns a row scaling into a column scaling. *)
Lemma row_scale_transpose : {n m} (A : Matrix n m) (x : nat) (a : C),
  (row_scale A x a)⊤ = col_scale (A) x a.

(** Transpose turns a column addition into a row addition. *)
Lemma col_add_transpose : {n m} (A : Matrix n m) (col to_add : nat) (a : C),
  (col_add A col to_add a)⊤ = row_add (A) col to_add a.

(** Transpose turns a row addition into a column addition. *)
Lemma row_add_transpose : {n m} (A : Matrix n m) (row to_add : nat) (a : C),
  (row_add A row to_add a)⊤ = col_add (A) row to_add a.

(** Transpose turns [col_add_many] into [row_add_many]. *)
Lemma col_add_many_transpose : {n m} (A : Matrix n m) (col : nat) (as' : Vector m),
  (col_add_many col as' A)⊤ = row_add_many col (as') (A).

(** Transpose turns [row_add_many] into [col_add_many]. *)
Lemma row_add_many_transpose : {n m} (A : Matrix n m) (row : nat) (as' : Matrix 1 n),
  (row_add_many row as' A)⊤ = col_add_many row (as') (A).

(** Transpose turns [col_add_each] into [row_add_each]. *)
Lemma col_add_each_transpose : {n m} (A : Matrix n m) (col : nat) (as' : Matrix 1 m),
  (col_add_each col as' A)⊤ = row_add_each col (as') (A).

(** Transpose turns [row_add_each] into [col_add_each]. *)
Lemma row_add_each_transpose : {n m} (A : Matrix n m) (row : nat) (as' : Vector n),
  (row_add_each row as' A)⊤ = col_add_each row (as') (A).

The idea is to show that column operations correspond to multiplication by special matrices. Thus, we show that the column operations all satisfy various multiplication rules.
(** For [x < y], both below [m], swapping columns [x] and [y] of [A]
    together with rows [x] and [y] of [B] leaves [A × B] unchanged. *)
Lemma swap_preserves_mul_lt : {n m o} (A : Matrix n m) (B : Matrix m o) (x y : nat),
  x < y x < m y < m A × B = (col_swap A x y) × (row_swap B x y).

(** Same as [swap_preserves_mul_lt] without the ordering [x < y]. *)
Lemma swap_preserves_mul : {n m o} (A : Matrix n m) (B : Matrix m o) (x y : nat),
  x < m y < m A × B = (col_swap A x y) × (row_swap B x y).

(** Scaling row [x] of the right factor by [a] is the same as scaling
    column [x] of the left factor by [a]. *)
Lemma scale_preserves_mul : {n m o} (A : Matrix n m) (B : Matrix m o) (x : nat) (a : C),
  A × (row_scale B x a) = (col_scale A x a) × B.

(** For [x < y], both below [m], a column addition [col_add A x y a] on
    the left factor corresponds to the row addition [row_add B y x a]
    (indices swapped) on the right factor. *)
Lemma add_preserves_mul_lt : {n m o} (A : Matrix n m) (B : Matrix m o)
                                                (x y : nat) (a : C),
   x < y x < m y < m A × (row_add B y x a) = (col_add A x y a) × B.

(** Same as [add_preserves_mul_lt] without the ordering [x < y]. *)
Lemma add_preserves_mul : {n m o} (A : Matrix n m) (B : Matrix m o)
                                             (x y : nat) (a : C),
   x < m y < m A × (row_add B y x a) = (col_add A x y a) × B.

(** [skip_count skip i] enumerates the naturals with [skip] left out: it
    returns [i] when [i < skip] and [S i] otherwise. Downstream proofs may
    case on the exact test [i <? skip], so keep this form. *)
Definition skip_count (skip i : nat) : nat :=
  if (i <? skip) then i else S i.

(** [skip_count] never decreases its argument. *)
Lemma skip_count_le : (skip i : nat),
  i skip_count skip i.

(** [skip_count skip i] never hits [skip] itself. *)
Lemma skip_count_not_skip : (skip i : nat),
  skip skip_count skip i.

(** [skip_count skip] is strictly monotone. *)
Lemma skip_count_mono : (skip i1 i2 : nat),
  i1 < i2 skip_count skip i1 < skip_count skip i2.

(** When [as' col 0 = C0] and [to_add] differs from [col] (presumably
    [to_add ≠ col]), a [col_add] into column [col] commutes with
    [col_add_many col as']. *)
Lemma cam_ca_switch : {n m} (T : Matrix n m) (as' : Vector m) (col to_add : nat) (c : C),
  as' col 0 = C0 to_add col
  col_add (col_add_many col as' T) col to_add c =
  col_add_many col as' (col_add T col to_add c).

(** Restricted form of [col_add_many_preserves_mul] for vectors [v] that
    vanish at every index above [skip_count col e]. *)
Lemma col_add_many_preserves_mul_some : (n m o e col : nat)
                                               (A : Matrix n m) (B : Matrix m o) (v : Vector m),
  WF_Matrix v (skip_count col e) < m col < m
  ( i : nat, (skip_count col e) < i v i 0 = C0) v col 0 = C0
  A × (row_add_each col v B) = (col_add_many col v A) × B.

(** For a well-formed [v] with [v col 0 = C0] and [col < m],
    [row_add_each] on the right factor matches [col_add_many] on the left
    factor. *)
Lemma col_add_many_preserves_mul: (n m o col : nat)
                                               (A : Matrix n m) (B : Matrix m o) (v : Vector m),
  WF_Matrix v col < m v col 0 = C0
  A × (row_add_each col v B) = (col_add_many col v A) × B.

(** Analogue of [col_add_many_preserves_mul] with a row vector [v]:
    [row_add_many] on the right factor matches [col_add_each] on the left
    factor. *)
Lemma col_add_each_preserves_mul: (n m o col : nat) (A : Matrix n m)
                                                         (B : Matrix m o) (v : Matrix 1 m),
  WF_Matrix v col < m v 0 col = C0
  A × (row_add_many col v B) = (col_add_each col v A) × B.

(** A column operation on a well-formed square matrix [A] is right
    multiplication by the corresponding row operation applied to [I n]. *)
Lemma col_swap_mult_r : {n} (A : Square n) (x y : nat),
  x < n y < n WF_Matrix A
  col_swap A x y = A × (row_swap (I n) x y).

(** Column scaling as right multiplication by a row-scaled identity. *)
Lemma col_scale_mult_r : {n} (A : Square n) (x : nat) (a : C),
  WF_Matrix A
  col_scale A x a = A × (row_scale (I n) x a).

(** Column addition as right multiplication by a row-added identity
    (indices swapped). *)
Lemma col_add_mult_r : {n} (A : Square n) (x y : nat) (a : C),
  x < n y < n WF_Matrix A
  col_add A x y a = A × (row_add (I n) y x a).

(** [col_add_many] as right multiplication by [row_add_each] applied to
    the identity. *)
Lemma col_add_many_mult_r : {n} (A : Square n) (v : Vector n) (col : nat),
  WF_Matrix A WF_Matrix v col < n v col 0 = C0
  col_add_many col v A = A × (row_add_each col v (I n)).

(** [col_add_each] as right multiplication by [row_add_many] applied to
    the identity. *)
Lemma col_add_each_mult_r : {n} (A : Square n) (v : Matrix 1 n) (col : nat),
  WF_Matrix A WF_Matrix v col < n v 0 col = C0
  col_add_each col v A = A × (row_add_many col v (I n)).

(** On the identity, column and row operations coincide. For additions
    the index order is swapped, and [_each]/[_many] trade places. *)
Lemma col_row_swap_invr_I : (n x y : nat),
  x < n y < n col_swap (I n) x y = row_swap (I n) x y.

Lemma col_row_scale_invr_I : (n x : nat) (c : C),
  col_scale (I n) x c = row_scale (I n) x c.

Lemma col_row_add_invr_I : (n x y : nat) (c : C),
  x < n y < n col_add (I n) x y c = row_add (I n) y x c.

Lemma row_each_col_many_invr_I : (n col : nat) (v : Vector n),
  WF_Matrix v col < n v col 0 = C0
  row_add_each col v (I n) = col_add_many col v (I n).

Lemma row_many_col_each_invr_I : (n col : nat) (v : Matrix 1 n),
  WF_Matrix v col < n v 0 col = C0
  row_add_many col v (I n) = col_add_each col v (I n).

(** A well-formed [T] with [S m] columns is its first [m] columns
    ([reduce_col T m]) with column [m] appended. *)
Lemma reduce_append_split : {n m} (T : Matrix n (S m)),
  WF_Matrix T T = col_append (reduce_col T m) (get_vec m T).

(** Smashing a well-formed [T] with an [n × i] zero block gives [T]. *)
Lemma smash_zero : {n m} (T : Matrix n m) (i : nat),
  WF_Matrix T smash T (@Zero n i) = T.

(** [smash] is associative. *)
Lemma smash_assoc : {n m1 m2 m3}
                           (T1 : Matrix n m1) (T2 : Matrix n m2) (T3 : Matrix n m3),
  smash (smash T1 T2) T3 = smash T1 (smash T2 T3).

(** For well-formed [T] and [v], appending the column [v] is the same as
    smashing with [v]. *)
Lemma smash_append : {n m} (T : Matrix n m) (v : Vector n),
  WF_Matrix T WF_Matrix v
  col_append T v = smash T v.

(** Removing the last column of [smash T1 T2] removes the last column of
    [T2]. *)
Lemma smash_reduce : {n m1 m2} (T1 : Matrix n m1) (T2 : Matrix n (S m2)),
  reduce_col (smash T1 T2) (m1 + m2) = smash T1 (reduce_col T2 m2).

(** Any [T] with at least one column is its column [0] smashed with the
    remaining columns. *)
Lemma split_col : {n m} (T : Matrix n (S m)),
  T = smash (get_vec 0 T) (reduce_col T 0).

Some more lemmas using these new concepts


(** Entrywise equivalence [==] of vectors is decidable. *)
Lemma vec_equiv_dec : {n : nat} (A B : Vector n),
    { A == B } + { ¬ (A == B) }.

(** Entrywise equivalence [==] of matrices is decidable. *)
Lemma mat_equiv_dec : {n m : nat} (A B : Matrix n m),
    { A == B } + { ¬ (A == B) }.

(** A well-formed vector whose last entry [v n 0] is [C0] equals
    [reduce_vecn v]. *)
Lemma last_zero_simplification : {n : nat} (v : Vector (S n)),
  WF_Matrix v v n 0 = C0 v = reduce_vecn v.

(** For a well-formed [v]: [v = Zero] exactly when removing entry [x]
    leaves [Zero] and entry [x] itself is [C0]. *)
Lemma zero_reduce : {n : nat} (v : Vector (S n)) (x : nat),
  WF_Matrix v (v = Zero (reduce_row v x) = Zero v x 0 = C0).

(** A well-formed vector different from [Zero] has a nonzero entry. *)
Lemma nonzero_vec_nonzero_elem : {n} (v : Vector n),
  WF_Matrix v v Zero x, v x 0 C0.

Local Close Scope nat_scope.

(** The inner product is conjugate-linear in its first argument:
    [⟨c .* u, v⟩ = c^* * ⟨u, v⟩]. *)
Lemma inner_product_scale_l : {n} (u v : Vector n) (c : C),
  c .* u, v = c^* × u,v.

(** ... and linear in its second argument. *)
Lemma inner_product_scale_r : {n} (u v : Vector n) (c : C),
  u, c .* v = c × u,v.

(** Additivity in the first argument. *)
Lemma inner_product_plus_l : {n} (u v w : Vector n),
  u .+ v, w = u, w + v, w.

(** Additivity in the second argument. *)
Lemma inner_product_plus_r : {n} (u v w : Vector n),
  u, v .+ w = u, v + u, w.

(** The inner product distributes over a [big_sum] in its first argument. *)
Lemma inner_product_big_sum_l : {n} (u : Vector n) (f : nat Vector n) (k : nat),
  big_sum f k, u = big_sum (fun if i, u) k.

(** The inner product distributes over a [big_sum] in its second argument. *)
Lemma inner_product_big_sum_r : {n} (u : Vector n) (f : nat Vector n) (k : nat),
  u, big_sum f k = big_sum (fun iu, f i) k.

(** Conjugate symmetry: [⟨u, v⟩] is the conjugate of [⟨v, u⟩]. *)
Lemma inner_product_conj_sym : {n} (u v : Vector n),
  u, v = v, u^*.

(** The inner product is unaffected by [make_WF] on its first argument.
    NOTE: the name misspells [make] as [mafe]; it is kept as is so that
    existing callers keep working. *)
Lemma inner_product_mafe_WF_l : {n} (u v : Vector n),
  u, v = make_WF u, v.

(** The inner product is unaffected by [make_WF] on its second argument
    (same naming note as the left-hand version). *)
Lemma inner_product_mafe_WF_r : {n} (u v : Vector n),
  u, v = u, make_WF v.

(** Euclidean norm: the square root of the real part ([fst]) of the inner
    product of [ψ] with itself. By [rewrite_norm], that real part is the
    sum of the squared moduli of the entries. *)
Definition norm {n} (ψ : Vector n) : R :=
  sqrt (fst ψ,ψ).

(** Scale [ψ] by the reciprocal of its norm. No check is made for
    [norm ψ = 0]; see [normalized_norm_1] for the nonzero case. *)
Definition normalize {n} (ψ : Vector n) :=
  / (norm ψ) .* ψ.

(** [norm (c .* v) = Cmod c * norm v]. *)
Lemma norm_scale : {n} c (v : Vector n), norm (c .* v) = ((Cmod c) × norm v)%R.

(** A vector with nonzero norm normalizes to a vector of norm [1]. *)
Lemma normalized_norm_1 : {n} (v : Vector n),
  norm v 0 norm (normalize v) = 1.

(** The real part of the self inner product is the sum, over the first [d]
    rows, of the squared moduli of the entries. *)
Lemma rewrite_norm : {d} (ψ : Vector d),
    fst ψ,ψ = big_sum (fun iCmod (ψ i O) ^ 2)%R d.

Local Open Scope nat_scope.

(** The self inner product has zero imaginary part ([snd]). *)
Lemma norm_real : {n} (v : Vector n), snd v,v = 0%R.

(** The real part of the self inner product is nonnegative. *)
Lemma inner_product_ge_0 : {d} (ψ : Vector d),
  (0 fst ψ,ψ)%R.

(** The norm is nonnegative. *)
Lemma norm_ge_0 : {d} (ψ : Vector d),
  (0 norm ψ)%R.

(** The squared norm is the real part of the self inner product. *)
Lemma norm_squared : {d} (ψ : Vector d),
  ((norm ψ) ^2)%R = fst ψ, ψ .

(** For a well-formed [v], the self inner product is [C0] exactly when
    [v = Zero]. *)
Lemma inner_product_zero_iff_zero : {n} (v : Vector n),
  WF_Matrix v (v,v = C0 v = Zero).

(** For a well-formed [v], the norm is [0] exactly when [v = Zero]. *)
Lemma norm_zero_iff_zero : {n} (v : Vector n),
  WF_Matrix v (norm v = 0%R v = Zero).

Local Close Scope nat_scope.

(** Key identity behind Cauchy–Schwarz. The squared length of
    [⟨v,v⟩ .* u .+ -1 * ⟨v,u⟩ .* v] equals
    [fst ⟨v,v⟩ * (fst ⟨v,v⟩ * fst ⟨u,u⟩ - (Cmod ⟨u,v⟩)^2)]. *)
Lemma CS_key_lemma : {n} (u v : Vector n),
  fst (⟨v,v .* u .+ -1 × v,u .* v), (⟨v,v .* u .+ -1 × v,u .* v) =
    ((fst v,v) × ((fst v,v)* (fst u,u) - (Cmod u,v)^2 ))%R.

(** Real-number helper: from [0 ≤ a], [0 < b] and [a = b * c], conclude
    [0 ≤ c]. *)
Lemma real_ge_0_aux : (a b c : R),
  0 a 0 < b (a = b × c)%R
  0 c.

(** Cauchy–Schwarz, squared form: [(Cmod ⟨u,v⟩)^2] is at most
    [fst ⟨u,u⟩ * fst ⟨v,v⟩]. (The name's [Schwartz] spelling is kept for
    existing callers.) *)
Lemma Cauchy_Schwartz_ver1 : {n} (u v : Vector n),
  (Cmod u,v)^2 (fst u,u) × (fst v,v).

(** Cauchy–Schwarz, norm form: [Cmod ⟨u,v⟩] is at most [norm u * norm v]. *)
Lemma Cauchy_Schwartz_ver2 : {n} (u v : Vector n),
  (Cmod u,v) norm u × norm v.

(** Cauchy–Schwarz written entrywise for column vectors. *)
Lemma Cplx_Cauchy_vector :
   n (u v : Vector n),
    ((big_sum (fun iCmod (u i O) ^ 2) n) × (big_sum (fun iCmod (v i O) ^ 2) n)
     Cmod (big_sum (fun i ⇒ ((u i O)^* × (v i O))%C) n) ^ 2)%R.

Local Open Scope nat_scope.

(** Cauchy–Schwarz for functions [nat → C] over the first [n] indices. *)
Lemma Cplx_Cauchy :
   n (u v : nat C),
    ((big_sum (fun iCmod (u i) ^ 2) n) × (big_sum (fun iCmod (v i) ^ 2) n) Cmod (big_sum (fun i ⇒ ((u i)^* × (v i))%C) n) ^ 2)%R.

Tactics




(** [vec_to_list' n v] lists entries [v (nmax - n) O], ..., [v (nmax - 1) O]
    of column [O] of [v], in increasing row order. For [n ≤ nmax] these are
    the last [n] entries of the column (see [nth_vec_to_list']). *)
Fixpoint vec_to_list' {nmax : nat} (n : nat) (v : Vector nmax) :=
  match n with
  | O ⇒ nil
  | S n' ⇒ v (nmax - n)%nat O :: vec_to_list' n' v
  end.

(** All entries of the column vector [v], top to bottom. *)
Definition vec_to_list {n : nat} (v : Vector n) := vec_to_list' n v.

(** [vec_to_list' m v] has exactly [m] elements. *)
Lemma vec_to_list'_length : m n (v : Vector n), length (vec_to_list' m v) = m.

(** [vec_to_list v] has exactly [n] elements. *)
Lemma vec_to_list_length : n (v : Vector n), length (vec_to_list v) = n.

(** For [m ≤ n] and [x < m], the [x]-th element of [vec_to_list' m v] is
    [v (n - m + x) O]. *)
Lemma nth_vec_to_list' : {m n} (v : Vector n) x,
  (m n)%nat (x < m)%nat nth x (vec_to_list' m v) C0 = v (n - m + x)%nat O.

(** For [x < n], the [x]-th element of [vec_to_list v] is [v x O]. *)
Lemma nth_vec_to_list : n (v : Vector n) x,
  (x < n)%nat nth x (vec_to_list v) C0 = v x O.

Restoring Matrix Dimensions
Tactics that rebuild matrix expressions so that every operation carries explicit, mutually consistent dimension arguments.
(** Succeeds, doing nothing, exactly when the term [n] has type [nat];
    fails otherwise. *)
Ltac is_nat n := match type of n with nat ⇒ idtac end.

(** Succeeds when the goal is an equation whose left-hand side has type
    [nat]; fails otherwise. Used as a guard by [unify_matrix_dims]. *)
Ltac is_nat_equality :=
  match goal with
  | |- ?A = ?B ⇒ is_nat A
  end.

(** Close an equality between two terms that differ only in dimension
    arguments. After a first [reflexivity] attempt, repeatedly apply
    [f_equal_gen] (presumably splitting [f a = g b] into [f = g] and
    [a = b]). Each resulting goal is tried with [reflexivity]; equalities
    between [nat]s that remain are handed to [tac]. *)
Ltac unify_matrix_dims tac :=
  try reflexivity;
  repeat (apply f_equal_gen; try reflexivity;
          try (is_nat_equality; tac)).

(** [restore_dims_rec A] rebuilds the term [A] bottom-up so that every
    matrix operation carries explicit dimension arguments, read off the
    types of its already-rebuilt operands.
    - Special cases come first. In a product or sum with an identity [I _]
      or [Zero] operand, that operand's dimensions are taken from the other
      operand.
    - General cases handle [=], [×], [⊗], [†], [.+], [.*] and [⨂].
    - Consider a function [P] of type [nat → nat → Matrix _ _ → Prop] or
      [nat → Matrix _ _ → Prop] applied to a matrix. Its [nat] arguments
      are recomputed from the rebuilt matrix type.
    - Any other application is rebuilt componentwise. Every other term is
      returned unchanged. *)
Ltac restore_dims_rec A :=
   match A with
  (* special cases: identity and zero operands *)
  | ?A × I _          ⇒ let A' := restore_dims_rec A in
                        match type of A' with
                        | Matrix ?m' ?n' ⇒ constr:(@Mmult m' n' n' A' (I n'))
                        end
  | I _ × ?B          ⇒ let B' := restore_dims_rec B in
                        match type of B' with
                        | Matrix ?n' ?o' ⇒ constr:(@Mmult n' n' o' (I n') B')
                        end
  | ?A × @Zero ?n ?n  ⇒ let A' := restore_dims_rec A in
                        match type of A' with
                        | Matrix ?m' ?n' ⇒ constr:(@Mmult m' n' n' A' (@Zero n' n'))
                        end
  | @Zero ?n ?n × ?B  ⇒ let B' := restore_dims_rec B in
                        match type of B' with
                        | Matrix ?n' ?o' ⇒ constr:(@Mmult n' n' o' (@Zero n' n') B')
                        end
  | ?A × @Zero ?n ?o  ⇒ let A' := restore_dims_rec A in
                        match type of A' with
                        | Matrix ?m' ?n' ⇒ constr:(@Mmult m' n' o A' (@Zero n' o))
                        end
  | @Zero ?m ?n × ?B  ⇒ let B' := restore_dims_rec B in
                        match type of B' with
                        (* The product has the row count [m] of the zero
                           factor, not the row count [n'] of [B']. *)
                        | Matrix ?n' ?o' ⇒ constr:(@Mmult m n' o' (@Zero m n') B')
                        end
  | ?A .+ @Zero ?m ?n ⇒ let A' := restore_dims_rec A in
                        match type of A' with
                        | Matrix ?m' ?n' ⇒ constr:(@Mplus m' n' A' (@Zero m' n'))
                        end
  | @Zero ?m ?n .+ ?B ⇒ let B' := restore_dims_rec B in
                        match type of B' with
                        | Matrix ?m' ?n' ⇒ constr:(@Mplus m' n' (@Zero m' n') B')
                        end
  (* general cases *)
  | ?A = ?B  ⇒ let A' := restore_dims_rec A in
                let B' := restore_dims_rec B in
                match type of A' with
                | Matrix ?m' ?n' ⇒ constr:(@eq (Matrix m' n') A' B')
                  end
  | ?A × ?B  ⇒ let A' := restore_dims_rec A in
                let B' := restore_dims_rec B in
                match type of A' with
                | Matrix ?m' ?n' ⇒
                  match type of B' with
                  | Matrix ?n'' ?o' ⇒ constr:(@Mmult m' n' o' A' B')
                  end
                end
  | ?A ⊗ ?B  ⇒ let A' := restore_dims_rec A in
                let B' := restore_dims_rec B in
                match type of A' with
                | Matrix ?m' ?n' ⇒
                  match type of B' with
                  | Matrix ?o' ?p' ⇒ constr:(@kron m' n' o' p' A' B')
                  end
                end
  | ?A †     ⇒ let A' := restore_dims_rec A in
                match type of A' with
                | Matrix ?m' ?n' ⇒ constr:(@adjoint m' n' A')
                end
  | ?A .+ ?B ⇒ let A' := restore_dims_rec A in
               let B' := restore_dims_rec B in
               match type of A' with
               | Matrix ?m' ?n' ⇒
                 match type of B' with
                 | Matrix ?m'' ?n'' ⇒ constr:(@Mplus m' n' A' B')
                 end
               end
  | ?c .* ?A ⇒ let A' := restore_dims_rec A in
               match type of A' with
               | Matrix ?m' ?n' ⇒ constr:(@scale m' n' c A')
               end
  | ?n ⨂ ?A  ⇒ let A' := restore_dims_rec A in
               match type of A' with
               | Matrix ?m' ?n' ⇒ constr:(@kron_n n m' n' A')
               end
  (* predicates over matrices whose leading [nat] arguments are dimensions *)
  | ?P ?m ?n ?A ⇒ match type of P with
                  | nat → nat → Matrix _ _ → Prop ⇒
                    let A' := restore_dims_rec A in
                    match type of A' with
                    | Matrix ?m' ?n' ⇒ constr:(P m' n' A')
                    end
                  end
  | ?P ?n ?A ⇒ match type of P with
               | nat → Matrix _ _ → Prop ⇒
                 let A' := restore_dims_rec A in
                 match type of A' with
                 | Matrix ?m' ?n' ⇒ constr:(P m' A')
                 end
               end
  (* any other application: rebuild function and argument *)
  | ?f ?A    ⇒ let f' := restore_dims_rec f in
               let A' := restore_dims_rec A in
               constr:(f' A')
  (* leaves are returned unchanged *)
  | ?A       ⇒ A
   end.

(** Replace the goal [A] by the form computed by [restore_dims_rec], with
    explicit and consistent dimension arguments. The two are proved equal
    by [unify_matrix_dims tac], with [tac] discharging [nat] equalities. *)
Ltac restore_dims tac :=
  match goal with
  | |- ?A ⇒ let A' := restore_dims_rec A in
            replace A with A' by unify_matrix_dims tac
  end.

Tactic Notation "restore_dims" tactic(tac) := restore_dims tac.

(** Default dimension solver: power-of-two arithmetic ([Nat.pow_1_l],
    [ring], [unify_pows_two]) followed by [lia]. *)
Tactic Notation "restore_dims" := restore_dims (repeat rewrite Nat.pow_1_l; try ring; unify_pows_two; simpl; lia).


(** For a well-formed [A], the [(n + m)]-fold Kronecker power of [A] splits
    into the Kronecker product of its [n]-fold and [m]-fold powers
    (presumably [(n + m) ⨂ A = n ⨂ A ⊗ m ⨂ A]). *)
Lemma kron_n_m_split {o p} : n m (A : Matrix o p),
  WF_Matrix A (n + m) A = n A m A.

Matrix Simplification


(* [M_db_light]: identity and zero laws. Side conditions (well-formedness)
   are discharged by [auto 100 with wf_db]. *)
Hint Rewrite @kron_1_l @kron_1_r @Mmult_1_l @Mmult_1_r @Mscale_1_l
     @id_adjoint_eq @id_transpose_eq using (auto 100 with wf_db) : M_db_light.
Hint Rewrite @kron_0_l @kron_0_r @Mmult_0_l @Mmult_0_r @Mplus_0_l @Mplus_0_r
     @Mscale_0_l @Mscale_0_r @zero_adjoint_eq @zero_transpose_eq using (auto 100 with wf_db) : M_db_light.

(** Restore dimensions if possible, then rewrite with [M_db_light]. *)
Ltac Msimpl_light := try restore_dims; autorewrite with M_db_light.

(* [M_db]: distribute adjoints, apply [kron_mixed_product], cancel double
   adjoints. *)
Hint Rewrite @Mmult_adjoint @Mplus_adjoint @kron_adjoint @kron_mixed_product
     @adjoint_involutive using (auto 100 with wf_db) : M_db.

(** Restore dimensions if possible, then rewrite with both databases. *)
Ltac Msimpl := try restore_dims; autorewrite with M_db_light M_db.

Distribute addition to the outside of matrix expressions.

(** Repeatedly distribute [×] and [⊗] over [.+] in the goal, so that sums
    float to the outside of matrix expressions. *)
Ltac distribute_plus :=
  repeat match goal with
  | |- context [?a × (?b .+ ?c)] ⇒ rewrite (Mmult_plus_distr_l _ _ _ a b c)
  | |- context [(?a .+ ?b) × ?c] ⇒ rewrite (Mmult_plus_distr_r _ _ _ a b c)
  | |- context [?a ⊗ (?b .+ ?c)] ⇒ rewrite (kron_plus_distr_l _ _ _ _ a b c)
  | |- context [(?a .+ ?b) ⊗ ?c] ⇒ rewrite (kron_plus_distr_r _ _ _ _ a b c)
  end.

Distribute scaling to the outside of matrix expressions

(** Repeatedly pull scalar factors [c .* _] out of [×] and [⊗], and combine
    nested scalings with [Mscale_assoc], so that scalars move to the
    outside of matrix expressions in the goal. *)
Ltac distribute_scale :=
  repeat
   match goal with
   | |- context [ (?c .* ?A) × ?B ] ⇒ rewrite (Mscale_mult_dist_l _ _ _ c A B)
   | |- context [ ?A × (?c .* ?B) ] ⇒ rewrite (Mscale_mult_dist_r _ _ _ c A B)
   | |- context [ (?c .* ?A) ⊗ ?B ] ⇒ rewrite (Mscale_kron_dist_l _ _ _ _ c A B)
   | |- context [ ?A ⊗ (?c .* ?B) ] ⇒ rewrite (Mscale_kron_dist_r _ _ _ _ c A B)
   | |- context [ ?c .* (?c' .* ?A) ] ⇒ rewrite (Mscale_assoc _ _ c c' A)
   end.

(** Repeatedly push [†] inside scalings, sums, products and Kronecker
    products in the goal. *)
Ltac distribute_adjoint :=
  repeat match goal with
  | |- context [(?c .* ?A)†] ⇒ rewrite (Mscale_adj _ _ c A)
  | |- context [(?A .+ ?B)†] ⇒ rewrite (Mplus_adjoint _ _ A B)
  | |- context [(?A × ?B)†] ⇒ rewrite (Mmult_adjoint A B)
  | |- context [(?A ⊗ ?B)†] ⇒ rewrite (kron_adjoint A B)
  end.

Tactics for solving computational matrix equalities

(** Create an existential variable named [t] of type [T]. The
    [match goal] wrapper presumably forces this to run as a tactic when it
    is invoked from a [let] binding, as in [evar_list]. *)
Ltac mk_evar t T := match goal with _ ⇒ evar (t : T) end.

(** [evar_list n] creates [n] fresh evars of type [C] and returns them as a
    [list C]. *)
Ltac evar_list n :=
  match n with
  | O ⇒ constr:(@nil C)
  | S ?n' ⇒ let e := fresh "e" in
            (* [none] is bound only to run [mk_evar] for its side effect *)
            let none := mk_evar e C in
            let ls := evar_list n' in
            constr:(e :: ls)
  end.

(** [evar_list_2d m n] returns [m] rows, each a fresh [evar_list n]. *)
Ltac evar_list_2d m n :=
  match m with
  | O ⇒ constr:(@nil (list C))
  | S ?m' ⇒ let ls := evar_list n in
            let ls2d := evar_list_2d m' n in
            constr:(ls :: ls2d)
  end.

(** An [m]-row, [n]-column matrix of fresh evars, built with
    [list2D_to_matrix]. *)
Ltac evar_matrix m n := let ls2d := (evar_list_2d m n)
                        in constr:(list2D_to_matrix ls2d).

(** [tac_lt m n] succeeds, doing nothing, iff [m] and [n] are concrete
    [O]/[S] numerals with [m < n]; it fails otherwise. This makes it usable
    as a guard inside [match goal] branches (see [assoc_least]). *)
Ltac tac_lt m n :=
  match n with
  | S ?n' ⇒ match m with
            | O ⇒ idtac
            | S ?m' ⇒ tac_lt m' n'
            end
  end.

(** Reassociate triple products with [Mmult_assoc], choosing the direction
    by strict comparisons ([tac_lt]) of concrete dimension numerals.
    For [(A × B) × C] with [A : m × n], [B : n × o] and [C : o × p],
    rewrite to [A × (B × C)] when [p < o] and [p < m], or when [n < o] and
    [n < m]. For [A × (B × C)] with [A : m × n], [B : n × o] and
    [C : o × p], rewrite to [(A × B) × C] when [m < n] and [m < p], or
    when [o < n] and [o < p]. *)
Ltac assoc_least :=
  repeat (simpl; match goal with
  | [|- context[@Mmult ?m ?o ?p (@Mmult ?m ?n ?o ?A ?B) ?C]] ⇒ tac_lt p o; tac_lt p m;
       let H := fresh "H" in
       specialize (Mmult_assoc A B C) as H; simpl in H; rewrite H; clear H
  | [|- context[@Mmult ?m ?o ?p (@Mmult ?m ?n ?o ?A ?B) ?C]] ⇒ tac_lt n o; tac_lt n m;
       let H := fresh "H" in
       specialize (Mmult_assoc A B C) as H; simpl in H; rewrite H; clear H
  | [|- context[@Mmult ?m ?n ?p ?A (@Mmult ?n ?o ?p ?B ?C)]] ⇒ tac_lt m n; tac_lt m p;
       let H := fresh "H" in
       specialize (Mmult_assoc A B C) as H; simpl in H; rewrite <- H; clear H
  | [|- context[@Mmult ?m ?n ?p ?A (@Mmult ?n ?o ?p ?B ?C)]] ⇒ tac_lt o n; tac_lt o p;
       let H := fresh "H" in
       specialize (Mmult_assoc A B C) as H; simpl in H; rewrite <- H; clear H
  end).

(** For every hypothesis [H : WF_Matrix M] and every entry [M a b] in the
    goal, rewrite that entry to [C0] with [H]. This applies when [lia] can
    prove the row index out of bounds (the [left] disjunct) or the column
    index out of bounds (the [right] disjunct). Then simplify with [C_db]
    and try [auto]. *)
Ltac solve_out_of_bounds :=
  repeat match goal with
  | [H : WF_Matrix ?M |- context[?M ?a ?b] ] ⇒
      rewrite (H a b) by (left; simpl; lia)
  | [H : WF_Matrix ?M |- context[?M ?a ?b] ] ⇒
      rewrite (H a b) by (right; simpl; lia)
  end;
  autorewrite with C_db; auto.

(** The quotient accumulator [n] of [Nat.divmod] can be pulled out as an
    additive offset. *)
Lemma divmod_eq : x y n z,
  fst (Nat.divmod x y n z) = (n + fst (Nat.divmod x y 0 z))%nat.

(** Special case of [divmod_eq] for an accumulator of the form [S n]. *)
Lemma divmod_S : x y n z,
  fst (Nat.divmod x y (S n) z) = (S n + fst (Nat.divmod x y 0 z))%nat.

(** Case-split on a variable [x] that is the scrutinee of a [match] on
    [nat] in the goal, either directly or through
    [fst (Nat.divmod x _ _ _)]. *)
Ltac destruct_m_1' :=
  match goal with
  | [ |- context[match ?x with
                 | 0 ⇒ _
                 | S _ ⇒ _
                 end] ] ⇒ is_var x; destruct x
  | [ |- context[match fst (Nat.divmod ?x _ _ _) with
                 | 0 ⇒ _
                 | S _ ⇒ _
                 end] ] ⇒ is_var x; destruct x
  end.

(** With divisor argument [0], [fst (Nat.divmod x 0 q 0)] is [x + q]. *)
Lemma divmod_0q0 : x q, fst (Nat.divmod x 0 q 0) = (x + q)%nat.

(** Special case of [divmod_0q0] with [q = 0]. *)
Lemma divmod_0 : x, fst (Nat.divmod x 0 0 0) = x.

(** Alternate [destruct_m_1'], the [divmod_0]/[divmod_S] rewrites and
    [simpl] for as long as this makes progress. *)
Ltac destruct_m_eq' := repeat
  (progress (try destruct_m_1'; try rewrite divmod_0; try rewrite divmod_S; simpl)).


(** Brute-force a matrix equality entry by entry.
    1. Print the goal (prefixed with [Crunching:]).
    2. Unfold and clear every local definition [c : C].
    3. Unfold [list2D_to_matrix] and [U_db], then go pointwise with
       [prep_matrix_equality].
    4. Case-split the indices with [destruct_m_eq'] and simplify with
       [Csimpl].
    5. Finish with [reflexivity] or [solve_out_of_bounds]. *)
Ltac crunch_matrix :=
                    match goal with
                      | [|- ?G ] ⇒ idtac "Crunching:" G
                      end;
                      repeat match goal with
                             | [ c : C |- _ ] ⇒ cbv [c]; clear c
                             end;
                      simpl;
                      unfold list2D_to_matrix;
                      autounfold with U_db;
                      prep_matrix_equality;
                      simpl;
                      destruct_m_eq';
                      simpl;
                      Csimpl;
                      try reflexivity;
                      try solve_out_of_bounds.

(** Succeeds when [M] is a product or a sum, possibly under any number of
    adjoints; fails otherwise. *)
Ltac compound M :=
  match M with
  | ?A × ?B  ⇒ idtac
  | ?A .+ ?B ⇒ idtac
  | ?A †     ⇒ compound A
  end.

(** Reduce one product or sum inside [M] to a matrix of fresh evars.
    First descend into an operand of [M] for which [compound] succeeds.
    When no descent succeeds, replace [M] itself by an evar matrix of its
    dimensions ([evar_matrix]) and prove the equality with
    [crunch_matrix]. *)
Ltac reduce_aux M :=
  match M with
  | ?A .+ ?B ⇒ compound A; reduce_aux A
  | ?A .+ ?B ⇒ compound B; reduce_aux B
  | ?A × ?B  ⇒ compound A; reduce_aux A
  | ?A × ?B  ⇒ compound B; reduce_aux B
  | @Mmult ?m ?n ?o ?A ?B ⇒ let M' := evar_matrix m o in
                            replace M with M';
                            (* second goal: the equality between [M'] and [M] *)
                            [| crunch_matrix ]
  | @Mplus ?m ?n ?A ?B ⇒ let M' := evar_matrix m n in
                         replace M with M';
                         [| crunch_matrix ]
  end.

(** Reduce one compound subterm of either side of an equality goal with
    [reduce_aux], then unfold and clear the evar names bound in list
    literals. *)
Ltac reduce_matrix := match goal with
                       | [ |- ?M = _] ⇒ reduce_aux M
                       | [ |- _ = ?M] ⇒ reduce_aux M
                       end;
                       repeat match goal with
                              | [ |- context[?c :: _ ]] ⇒ cbv [c]; clear c
                              end.

(** Like [reduce_matrix], but first reassociates with [assoc_least] and
    may act on any subterm of the goal. *)
Ltac reduce_matrices := assoc_least;
                        match goal with
                        | [ |- context[?M]] ⇒ reduce_aux M
                        end;
                        repeat match goal with
                               | [ |- context[?c :: _ ]] ⇒ cbv [c]; clear c
                               end.

(** Solve a concrete matrix equality: reassociate, reduce compound
    subterms as far as possible, crunch entrywise, then finish with
    [C_db] rewriting and [lca]. *)
Ltac solve_matrix := assoc_least;
                     repeat reduce_matrix; try crunch_matrix;
                     (* simplify [Nat.ltb] index comparisons *)
                     unfold Nat.ltb; simpl; try rewrite andb_false_r;
                     (* finish the complex-number arithmetic *)
                     autorewrite with C_db; try lca.

Gridify
Gridify turns a matrix expression into a normal form with addition on the outside, then tensor (Kronecker) products, then matrix multiplication. For example: ((..×..×..)⊗(..×..×..)⊗(..×..×..)) .+ ((..×..)⊗(..×..))

(** From [a < b] and [d = b - a - 1], conclude [b = a + 1 + d]. *)
Lemma repad_lemma1_l : (a b d : nat),
  a < b d = (b - a - 1) b = a + 1 + d.

(** As [repad_lemma1_l], with the sum written [d + 1 + a]. *)
Lemma repad_lemma1_r : (a b d : nat),
  a < b d = (b - a - 1) b = d + 1 + a.

(** From [a ≤ b] and [d = b - a], conclude [b = a + d]. *)
Lemma repad_lemma2 : (a b d : nat),
  a b d = (b - a) b = a + d.

(** [a ≤ b] gives some [d] with [b = d + a]. *)
Lemma le_ex_diff_l : a b, a b d, b = d + a.

(** [a ≤ b] gives some [d] with [b = a + d]. *)
Lemma le_ex_diff_r : a b, a b d, b = a + d.

(** [a < b] gives some [d] with [b = d + 1 + a]. *)
Lemma lt_ex_diff_l : a b, a < b d, b = d + 1 + a.

(** [a < b] gives some [d] with [b = a + 1 + d]. *)
Lemma lt_ex_diff_r : a b, a < b d, b = a + 1 + d.

(** Eliminate natural-number subtractions from the goal.
    - [b - a - 1], given [H : a < b]: name it [d] and turn [H] into
      [b = a + 1 + d] via [repad_lemma1_l].
    - [b - a], given [H : a ≤ b]: name it [d] and turn [H] into [b = a + d]
      via [repad_lemma2].
    In both cases [H] is rewritten everywhere, and then [b] and [H] are
    cleared if possible. *)
Ltac remember_differences :=
  repeat match goal with
  | H : ?a < ?b |- context[?b - ?a - 1] ⇒
    let d := fresh "d" in
    let R := fresh "R" in
    remember (b - a - 1) as d eqn:R ;
    apply (repad_lemma1_l a b d) in H; trivial;
    clear R;
    try rewrite H in *;
    try clear b H
  | H : ?a ≤ ?b |- context [ ?b - ?a ] ⇒
    let d := fresh "d" in
    let R := fresh "R" in
    remember (b - a) as d eqn:R ;
    apply (repad_lemma2 a b d) in H; trivial;
    clear R;
    try rewrite H in *;
    try clear b H
  end.

(** [get_dimensions M] returns, as a [nat] term, the base-2 logarithm of
    the square dimension of [M]:
    - [1] for a [Matrix 2 2];
    - [2] for a [Matrix 4 4];
    - [a] for a [Matrix (2^a) (2^a)].
    The results for the two factors of a [⊗] are added. A sum [A .+ B] is
    measured by [A]. Fails when no case applies. *)
Ltac get_dimensions M :=
  match M with
  | ?A ⊗ ?B  ⇒ let a := get_dimensions A in
               let b := get_dimensions B in
               constr:(a + b)
  | ?A .+ ?B ⇒ get_dimensions A
  | _        ⇒ match type of M with
               | Matrix 2 2 ⇒ constr:(1)
               | Matrix 4 4 ⇒ constr:(2)
               | Matrix (2^?a) (2^?a) ⇒ constr:(a)
               end
  end.


(** For a product [A × B] in the goal, compute [get_dimensions] of both
    factors and add the equation between them as a hypothesis (proved by
    [lia]). *)
Ltac hypothesize_dims :=
  match goal with
  | |- context[?A × ?B] ⇒ let a := get_dimensions A in
                         let b := get_dimensions B in
                         assert(a = b) by lia
  end.

(** Eliminate dimension side conditions from the context.
    - Each [a < b] is replaced, via [lt_ex_diff_r], by [b = a + 1 + x] for
      a fresh [x], and [b] is substituted away.
    - Sums in equations are reassociated to the right.
    - Common left summands are cancelled ([Nat.add_cancel_l]).
    - Equations [a + _ = b + _] with different heads are split into the
      cases [a < b], [a = b] and [b < a].
    Finally [lia] is tried. *)
Ltac fill_differences :=
  repeat match goal with
  | R : _ < _ |- _ ⇒ destruct (lt_ex_diff_r _ _ R);
                     clear R; subst
  | H : _ = _ |- _ ⇒ rewrite <- Nat.add_assoc in H
  | H : ?a + _ = ?a + _ |- _ ⇒ apply Nat.add_cancel_l in H; subst
  | H : ?a + _ = ?b + _ |- _ ⇒ destruct (lt_eq_lt_dec a b) as [[?|?]|?]; subst
  end; try lia.

(** Normalize the dimension side conditions of a goal: case on boolean
    comparisons, name differences, record dimension equations, and fill
    in strict inequalities. *)
Ltac repad :=
  (* remove boolean comparisons *)
  bdestruct_all; Msimpl_light; try reflexivity;
  (* replace subtractions by fresh names *)
  remember_differences;
  (* record dimension equations when they can be derived *)
  try hypothesize_dims; clear_dups;
  (* where a < b, replace b with a + 1 + fresh *)
  fill_differences.

(** Put a matrix goal into grid normal form: sums outside, then Kronecker
    products, then matrix products. *)
Ltac gridify :=
  (* normalize dimension side conditions; identical to the first half of
     the original script *)
  repad;
  (* make dimensions explicit and push sums outward *)
  restore_dims; distribute_plus;
  (* split identities of size 2^(a+b) into Kronecker products of smaller
     identities *)
  repeat rewrite Nat.pow_add_r;
  repeat rewrite <- id_kron; simpl;
  repeat rewrite Nat.mul_assoc;
  (* reassociate Kronecker products and fuse products of them *)
  restore_dims; repeat rewrite <- kron_assoc by auto 100 with wf_db;
  restore_dims; repeat rewrite kron_mixed_product;
  (* simplify *)
  Msimpl_light.

Tactics to show implicit arguments

(** Copies of [kron] and [Mmult] taking all arguments, dimensions included,
    explicitly. Rewriting with the shadow lemmas swaps between the
    originals and these copies (see [show_dimensions]). *)
Definition kron' := @kron.
Lemma kron_shadow : @kron = kron'.

Definition Mmult' := @Mmult.
Lemma Mmult_shadow : @Mmult = Mmult'.

(** Rewrite [kron] and [Mmult] into their fully-explicit copies [kron'] and
    [Mmult'], in the goal and in all hypotheses. The dimension arguments
    then presumably appear when printing. *)
Ltac show_dimensions := try rewrite kron_shadow in *;
                        try rewrite Mmult_shadow in *.
(** Undo [show_dimensions], in the goal and in all hypotheses. *)
Ltac hide_dimensions := try rewrite <- kron_shadow in *;
                        try rewrite <- Mmult_shadow in *.