QuantumLib.Quantum


In this file, we define objects and concepts specific to quantum computing and prove lemmas about them.

Require Import Psatz.
Require Import Reals.
Require Export VecSet.


Quantum basis states


(* Basis column vector ∣0⟩: entry (0,0) is C1, every other entry is C0. *)
Definition qubit0 : Vector 2 :=
  fun x y ⇒ match x, y with
          | 0, 0 ⇒ C1
          | 1, 0 ⇒ C0
          | _, _ ⇒ C0
          end.
(* Basis column vector ∣1⟩: entry (1,0) is C1, every other entry is C0. *)
Definition qubit1 : Vector 2 :=
  fun x y ⇒ match x, y with
          | 0, 0 ⇒ C0
          | 1, 0 ⇒ C1
          | _, _ ⇒ C0
          end.

(* Ket and bra notations for the two basis vectors.
   The bras are the adjoints of the kets: [bra] below has type Matrix 1 2,
   and the outer products ∣i⟩⟨j∣ multiply a 2x1 ket by a 1x2 bra. *)
Notation "∣0⟩" := qubit0.
Notation "∣1⟩" := qubit1.
Notation "⟨0∣" := qubit0†.
Notation "⟨1∣" := qubit1†.
Notation "∣0⟩⟨0∣" := (∣0⟩×⟨0∣).
Notation "∣1⟩⟨1∣" := (∣1⟩×⟨1∣).
Notation "∣1⟩⟨0∣" := (∣1⟩×⟨0∣).
Notation "∣0⟩⟨1∣" := (∣0⟩×⟨1∣).

(* Row vector for a natural-number label: ⟨0∣ when x is 0, ⟨1∣ for any other x. *)
Definition bra (x : nat) : Matrix 1 2 := if x =? 0 then ⟨0∣ else ⟨1∣.
(* Column vector for a natural-number label: ∣0⟩ when x is 0, ∣1⟩ for any other x. *)
Definition ket (x : nat) : Matrix 2 1 := if x =? 0 then ∣0⟩ else ∣1⟩.

(* Notations for [ket] and [bra] applied to an arbitrary label, and for a
   multi-qubit ket written as a comma-separated list. Each label in the list
   is turned into a ket before the Kronecker products are taken, so that
   ∣0,0⟩ is a 4-dimensional vector (see EPRpair). *)
Notation "'∣' x '⟩'" := (ket x).
Notation "'⟨' x '∣'" := (bra x).
Notation "∣ x , y , .. , z ⟩" := (kron .. (kron ∣x⟩ ∣y⟩) .. ∣z⟩) (at level 0).

(* Mark the basis-state definitions as transparent so they can be unfolded. *)
Transparent bra.
Transparent ket.
Transparent qubit0.
Transparent qubit1.

(* Basis ket selected by a boolean: true gives ∣1⟩, false gives ∣0⟩. *)
Definition bool_to_ket (b : bool) : Matrix 2 1 := if b then ∣1⟩ else ∣0⟩.

(* 2x2 outer product selected by a boolean: true gives ∣1⟩⟨1∣, false gives ∣0⟩⟨0∣. *)
Definition bool_to_matrix (b : bool) : Matrix 2 2 := if b then ∣1⟩⟨1∣ else ∣0⟩⟨0∣.

(* Entry-wise version of [bool_to_matrix]: a single 1 on the diagonal, at
   (1,1) when b is true and at (0,0) when b is false; 0 everywhere else. *)
Definition bool_to_matrix' (b : bool) : Matrix 2 2 := fun x y ⇒
  match x, y with
  | 0, 0 ⇒ if b then 0 else 1
  | 1, 1 ⇒ if b then 1 else 0
  | _, _ ⇒ 0
  end.

(* The two boolean-to-matrix definitions agree, and [bool_to_matrix] is the
   outer product of [bool_to_ket] with itself. *)
Lemma bool_to_matrix_eq : ∀ b, bool_to_matrix b = bool_to_matrix' b.

Lemma bool_to_ket_matrix_eq : ∀ b,
    outer_product (bool_to_ket b) (bool_to_ket b) = bool_to_matrix b.

(* Kronecker product (via big_kron) of [bool_to_matrix] applied to each
   element of l; the result is square of dimension 2^(length l). *)
Definition bools_to_matrix (l : list bool) : Square (2^(length l)) :=
  big_kron (map bool_to_matrix l).

(* A well-formed 2-dimensional vector is the linear combination of the two
   basis kets weighted by its own entries. *)
Lemma ket_decomposition : ∀ (ψ : Vector 2),
  WF_Matrix ψ →
  ψ = (ψ 0%nat 0%nat) .* ∣ 0 ⟩ .+ (ψ 1%nat 0%nat) .* ∣ 1 ⟩.



(* X- and Y-basis states: sums of ∣0⟩ and a scaled ∣1⟩ (scale 1, -1, Ci, -Ci),
   each multiplied by the normalising factor 1/√2. *)
Definition xbasis_plus : Vector 2 := / (√ 2) .* (∣0⟩ .+ ∣1⟩).
Definition xbasis_minus : Vector 2 := / (√ 2) .* (∣0⟩ .+ ((-1) .* ∣1⟩)).
Definition ybasis_plus : Vector 2 := / (√ 2) .* (∣0⟩ .+ Ci .* ∣1⟩).
Definition ybasis_minus : Vector 2 := / (√ 2) .* (∣0⟩ .+ ((-Ci) .* ∣1⟩)).

(* Ket notations for the X-basis (+, -) and Y-basis (R, L) states. *)
Notation "∣+⟩" := xbasis_plus.
Notation "∣-⟩" := xbasis_minus.
Notation "∣R⟩" := ybasis_plus.
Notation "∣L⟩" := ybasis_minus.

(* Two-qubit state (∣0,0⟩ + ∣1,1⟩) scaled by 1/√2, with its ket notation. *)
Definition EPRpair : Vector 4 := / (√ 2) .* (∣0,0⟩ .+ ∣1,1⟩).

Notation "∣Φ+⟩" := EPRpair.


Unitaries


(* Hadamard matrix: all four entries have magnitude 1/√2, with a negative
   sign at (1,1); entries outside the 2x2 block are 0. *)
Definition hadamard : Matrix 2 2 :=
  (fun x y ⇒ match x, y with
          | 0, 0 ⇒ (1 / √2)
          | 0, 1 ⇒ (1 / √2)
          | 1, 0 ⇒ (1 / √2)
          | 1, 1 ⇒ -(1 / √2)
          | _, _ ⇒ 0
          end).

(* k-fold Kronecker power of [hadamard]; the 0-fold power is the 1x1 identity. *)
Fixpoint hadamard_k (k : nat) : Matrix (2^k) (2^k):=
  match k with
  | 0 ⇒ I 1
  | S k' ⇒ hadamard ⊗ hadamard_k k'
  end.

(* The 1-fold Kronecker power of hadamard is hadamard itself. *)
Lemma hadamard_1 : hadamard_k 1 = hadamard.

(* Pauli X: C1 on the two off-diagonal entries, C0 elsewhere. *)
Definition σx : Matrix 2 2 :=
  fun x y ⇒ match x, y with
          | 0, 1 ⇒ C1
          | 1, 0 ⇒ C1
          | _, _ ⇒ C0
          end.

(* Pauli Y: -Ci at (0,1), Ci at (1,0), C0 elsewhere. *)
Definition σy : Matrix 2 2 :=
  fun x y ⇒ match x, y with
          | 0, 1 ⇒ -Ci
          | 1, 0 ⇒ Ci
          | _, _ ⇒ C0
          end.

(* Pauli Z: C1 at (0,0), -C1 at (1,1), C0 elsewhere. *)
Definition σz : Matrix 2 2 :=
  fun x y ⇒ match x, y with
          | 0, 0 ⇒ C1
          | 1, 1 ⇒ -C1
          | _, _ ⇒ C0
          end.

(* Matrix with (1 + Ci)/2 on the diagonal and (1 - Ci)/2 off the diagonal;
   the lemma sqrtx_sqrtx states that its square is σx. *)
Definition sqrtx : Matrix 2 2 :=
  fun x y ⇒ match x, y with
          | 0, 0 ⇒ (1 + Ci)/2
          | 0, 1 ⇒ (1 - Ci)/2
          | 1, 0 ⇒ (1 - Ci)/2
          | 1, 1 ⇒ (1 + Ci)/2
          | _, _ ⇒ C0
          end.

(* Multiplying sqrtx by itself gives σx. *)
Lemma sqrtx_sqrtx : sqrtx × sqrtx = σx.

(* Controlled version of A, as a (2n)x(2n) block matrix:
   - 1 on the diagonal entries whose row index is below n;
   - A (x-n) (y-n) when both indices are at least n;
   - 0 everywhere else. *)
Definition control {n : nat} (A : Matrix n n) : Matrix (2×n) (2×n) :=
  fun x y ⇒ if (x <? n) && (y =? x) then 1 else
          if (n <=? x) && (n <=? y) then A (x-n)%nat (y-n)%nat else 0.

(* CNOT given entry by entry: C1 at (0,0), (1,1), (2,3), (3,2); C0 elsewhere. *)
Definition cnot : Matrix (2×2) (2×2) :=
  fun x y ⇒ match x, y with
          | 0, 0 ⇒ C1
          | 1, 1 ⇒ C1
          | 2, 3 ⇒ C1
          | 3, 2 ⇒ C1
          | _, _ ⇒ C0
          end.

(* The entry-wise cnot coincides with the controlled σx. *)
Lemma cnot_eq : cnot = control σx.

(* Entry-wise matrix with 1 at (0,0), (2,2), (1,3), (3,1) and 0 elsewhere;
   notc_decomposition later relates it to σx and the basis projectors. *)
Definition notc : Matrix (2×2) (2×2) :=
  fun x y ⇒ match x, y with
          | 1, 3 ⇒ 1%C
          | 3, 1 ⇒ 1%C
          | 0, 0 ⇒ 1%C
          | 2, 2 ⇒ 1%C
          | _, _ ⇒ 0%C
          end.


(* SWAP given entry by entry: C1 at (0,0), (1,2), (2,1), (3,3); C0 elsewhere. *)
Definition swap : Matrix (2×2) (2×2) :=
  fun x y ⇒ match x, y with
          | 0, 0 ⇒ C1
          | 1, 2 ⇒ C1
          | 2, 1 ⇒ C1
          | 3, 3 ⇒ C1
          | _, _ ⇒ C0
          end.

(* Register the basis states and basic gates as unfoldable in the U_db hint database. *)
#[export] Hint Unfold qubit0 qubit1 hadamard σx σy σz control cnot swap bra ket : U_db.

Rotation Matrices


(* General single-qubit matrix parameterised by three real angles.
   θ enters through cos (θ/2) and sin (θ/2); ϕ and λ enter as Cexp phases. *)
Definition rotation (θ ϕ λ : R) : Matrix 2 2 :=
  fun x y ⇒ match x, y with
             | 0, 0 ⇒ (cos (θ/2))
             | 0, 1 ⇒ - (Cexp λ) × (sin (θ/2))
             | 1, 0 ⇒ (Cexp ϕ) × (sin (θ/2))
             | 1, 1 ⇒ (Cexp (ϕ + λ)) × (cos (θ/2))
             | _, _ ⇒ C0
             end.

(* Diagonal matrix with C1 at (0,0) and Cexp ϕ at (1,1). *)
Definition phase_shift (ϕ : R) : Matrix 2 2 :=
  fun x y ⇒ match x, y with
          | 0, 0 ⇒ C1
          | 1, 1 ⇒ Cexp ϕ
          | _, _ ⇒ C0
          end.

(* Matrix with cos (θ/2) on the diagonal and -Ci times sin (θ/2) off the diagonal. *)
Definition x_rotation (θ : R) : Matrix 2 2 :=
  fun x y ⇒ match x, y with
          | 0, 0 ⇒ cos (θ / 2)
          | 0, 1 ⇒ -Ci × sin (θ / 2)
          | 1, 0 ⇒ -Ci × sin (θ / 2)
          | 1, 1 ⇒ cos (θ / 2)
          | _, _ ⇒ 0
          end.

(* Matrix with cos (θ/2) on the diagonal, - sin (θ/2) at (0,1) and sin (θ/2) at (1,0). *)
Definition y_rotation (θ : R) : Matrix 2 2 :=
  fun x y ⇒ match x, y with
          | 0, 0 ⇒ cos (θ / 2)
          | 0, 1 ⇒ - sin (θ / 2)
          | 1, 0 ⇒ sin (θ / 2)
          | 1, 1 ⇒ cos (θ / 2)
          | _, _ ⇒ 0
          end.


(* S and T gates: phase shifts by PI/2 and PI/4 respectively. *)
Definition Sgate : Matrix 2 2 := phase_shift (PI / 2).

Definition Tgate := phase_shift (PI / 4).

(* How the named gates arise as special cases of [x_rotation], [y_rotation]
   and the general [rotation]. Statements only; proofs are not shown here. *)
Lemma x_rotation_pi : x_rotation PI = -Ci .* σx.

Lemma y_rotation_pi : y_rotation PI = -Ci .* σy.

Lemma hadamard_rotation : rotation (PI/2) 0 PI = hadamard.

Lemma pauli_x_rotation : rotation PI 0 PI = σx.

Lemma pauli_y_rotation : rotation PI (PI/2) (PI/2) = σy.

Lemma pauli_z_rotation : rotation 0 0 PI = σz.

Lemma Rx_rotation : ∀ θ, rotation θ (3×PI/2) (PI/2) = x_rotation θ.

Lemma Ry_rotation : ∀ θ, rotation θ 0 0 = y_rotation θ.

Lemma phase_shift_rotation : ∀ θ, rotation 0 0 θ = phase_shift θ.

Lemma I_rotation : rotation 0 0 0 = I 2.


(* sqrtx is a phase shift by PI/2 conjugated by hadamard. *)
Lemma sqrtx_decompose: sqrtx = hadamard × phase_shift (PI/2) × hadamard.


(* Inner products of the basis bras and kets. *)
Lemma Mmult00 : ⟨0∣ × ∣0⟩ = I 1.
Lemma Mmult01 : ⟨0∣ × ∣1⟩ = Zero.
Lemma Mmult10 : ⟨1∣ × ∣0⟩ = Zero.
Lemma Mmult11 : ⟨1∣ × ∣1⟩ = I 1.

(* σx exchanges the two basis kets (and the two basis bras). *)
Lemma MmultX1 : σx × ∣1⟩ = ∣0⟩.
Lemma Mmult1X : ⟨1∣ × σx = ⟨0∣.
Lemma MmultX0 : σx × ∣0⟩ = ∣1⟩.
Lemma Mmult0X : ⟨0∣ × σx = ⟨1∣.

(* The Pauli matrices and hadamard square to the identity; the two basis
   projectors sum to the identity. *)
Lemma MmultXX : σx × σx = I 2.
Lemma MmultYY : σy × σy = I 2.
Lemma MmultZZ : σz × σz = I 2.
Lemma MmultHH : hadamard × hadamard = I 2.
Lemma Mplus01 : ∣0⟩⟨0∣ .+ ∣1⟩⟨1∣ = I 2.
Lemma Mplus10 : ∣1⟩⟨1∣ .+ ∣0⟩⟨0∣ = I 2.

(* Applying hadamard to the first qubit of ∣0,0⟩ and then cnot yields EPRpair. *)
Lemma EPRpair_creation : cnot × (hadamard ⊗ I 2) × ∣0,0⟩ = EPRpair.

(* σx applied next to a basis bra or ket flips it, whatever it is multiplied with. *)
Lemma σx_on_right0 : ∀ (q : Vector 2), (q × ⟨0∣) × σx = q × ⟨1∣.

Lemma σx_on_right1 : ∀ (q : Vector 2), (q × ⟨1∣) × σx = q × ⟨0∣.

Lemma σx_on_left0 : ∀ (q : Matrix 1 2), σx × (∣0⟩ × q) = ∣1⟩ × q.

Lemma σx_on_left1 : ∀ (q : Matrix 1 2), σx × (∣1⟩ × q) = ∣0⟩ × q.

(* An adjacent bra/ket pair in the middle of a product cancels when the
   labels match and gives Zero when they differ. *)
Lemma cancel00 : ∀ (q1 : Matrix 2 1) (q2 : Matrix 1 2),
  WF_Matrix q2 →
  (q1 × ⟨0∣) × (∣0⟩ × q2) = q1 × q2.

Lemma cancel01 : ∀ (q1 : Matrix 2 1) (q2 : Matrix 1 2),
  (q1 × ⟨0∣) × (∣1⟩ × q2) = Zero.

Lemma cancel10 : ∀ (q1 : Matrix 2 1) (q2 : Matrix 1 2),
  (q1 × ⟨1∣) × (∣0⟩ × q2) = Zero.

Lemma cancel11 : ∀ (q1 : Matrix 2 1) (q2 : Matrix 1 2),
  WF_Matrix q2 →
  (q1 × ⟨1∣) × (∣1⟩ × q2) = q1 × q2.

(* Add the multiplication and cancellation lemmas to the Q_db rewrite database.
   The cancel lemmas carry a well-formedness side condition, discharged with wf_db. *)
Hint Rewrite Mmult00 Mmult01 Mmult10 Mmult11 Mmult0X MmultX0 Mmult1X MmultX1 : Q_db.
Hint Rewrite MmultXX MmultYY MmultZZ MmultHH Mplus01 Mplus10 EPRpair_creation : Q_db.
Hint Rewrite σx_on_right0 σx_on_right1 σx_on_left0 σx_on_left1 : Q_db.
Hint Rewrite cancel00 cancel01 cancel10 cancel11 using (auto with wf_db) : Q_db.

(* swap is its own inverse, also when it follows a well-formed matrix A. *)
Lemma swap_swap : swap × swap = I (2×2).

Lemma swap_swap_r : ∀ (A : Matrix (2×2) (2×2)),
  WF_Matrix A →
  A × swap × swap = A.

Hint Rewrite swap_swap swap_swap_r using (auto 100 with wf_db): Q_db.


(* Helper for [swap_to_0], by recursion on i.
   Base case: swap on the first two qubits, identity on the remaining n-2.
   Step: conjugate the recursive result by a swap placed after 2^i
   dimensions of identity. *)
Fixpoint swap_to_0_aux (n i : nat) {struct i} : Matrix (2^n) (2^n) :=
  match i with
  | O ⇒ swap ⊗ I (2^(n-2))
  | S i' ⇒ (I (2^i) ⊗ swap ⊗ I (2^(n-i-2))) ×
            swap_to_0_aux n i' ×
            (I (2^i) ⊗ swap ⊗ I (2^(n-i-2)))
  end.

(* Identity when i is 0; otherwise [swap_to_0_aux] on the predecessor of i.
   The example swap_to_0_test_24 shows it exchanging positions 0 and 2. *)
Definition swap_to_0 (n i : nat) : Matrix (2^n) (2^n) :=
  match i with
  | O ⇒ I (2^n)
  | S i' ⇒ swap_to_0_aux n i'
  end.

(* Helper for [swap_two], by recursion on i: peel off one leading identity
   qubit per step (decrementing n, i and j) until i is 0, then use [swap_to_0]. *)
Fixpoint swap_two_aux (n i j : nat) : Matrix (2^n) (2^n) :=
  match i with
  | O ⇒ swap_to_0 n j
  | S i' ⇒ I 2 ⊗ swap_two_aux (n-1) (i') (j-1)
  end.

(* Identity when i = j; otherwise calls [swap_two_aux] with the smaller
   index first, so the argument order of i and j does not matter. *)
Definition swap_two (n i j : nat) : Matrix (2^n) (2^n) :=
  if i =? j then I (2^n)
  else if i <? j then swap_two_aux n i j
  else swap_two_aux n j i.

(* Helper for [move_to_0], by recursion on i.
   Base case: swap on the first two qubits, identity on the remaining n-2.
   Step: the recursive result multiplied on the right by a swap placed after
   2^i dimensions of identity (no conjugation, unlike swap_to_0_aux). *)
Fixpoint move_to_0_aux (n i : nat) {struct i}: Matrix (2^n) (2^n) :=
  match i with
  | O ⇒ swap ⊗ I (2^(n-2))
  | S i' ⇒ (move_to_0_aux n i') × (I (2^i) ⊗ swap ⊗ I (2^(n-i-2)))
  end.

(* Identity when i is 0; otherwise [move_to_0_aux] on the predecessor of i.
   The example move_to_0_test_24 shows qubit 2 moved to the front with the
   order of the others kept. *)
Definition move_to_0 (n i : nat) : Matrix (2^n) (2^n) :=
  match i with
  | O ⇒ I (2^n)
  | S i' ⇒ move_to_0_aux n i'
  end.

(* By recursion on k: peel off one leading identity qubit per step
   (decrementing n and i) until k is 0, then use [move_to_0]. *)
Fixpoint move_to (n i k : nat) : Matrix (2^n) (2^n) :=
  match k with
  | O ⇒ move_to_0 n i
  | S k' ⇒ I 2 ⊗ move_to (n-1) (i-1) (k')
  end.


Well Formedness of Quantum States and Unitaries

(* Well-formedness of the basis states and of the boolean-indexed matrices,
   registered as wf_db hints. *)
Lemma WF_bra0 : WF_Matrix ⟨0∣.
Lemma WF_bra1 : WF_Matrix ⟨1∣.
Lemma WF_qubit0 : WF_Matrix ∣0⟩.
Lemma WF_qubit1 : WF_Matrix ∣1⟩.
Lemma WF_braqubit0 : WF_Matrix ∣0⟩⟨0∣.
Lemma WF_braqubit1 : WF_Matrix ∣1⟩⟨1∣.
Lemma WF_bool_to_ket : ∀ b, WF_Matrix (bool_to_ket b).
Lemma WF_bool_to_matrix : ∀ b, WF_Matrix (bool_to_matrix b).
Lemma WF_bool_to_matrix' : ∀ b, WF_Matrix (bool_to_matrix' b).

Lemma WF_ket : ∀ n, WF_Matrix (ket n).
Lemma WF_bra : ∀ n, WF_Matrix (bra n).

Lemma WF_bools_to_matrix : ∀ l,
  @WF_Matrix (2^(length l)) (2^(length l)) (bools_to_matrix l).

Lemma WF_xbasis_plus : WF_Matrix ∣+⟩.
Lemma WF_xbasis_minus : WF_Matrix ∣-⟩.
Lemma WF_ybasis_plus : WF_Matrix ∣R⟩.
Lemma WF_ybasis_minus : WF_Matrix ∣L⟩.

#[export] Hint Resolve WF_bra0 WF_bra1 WF_qubit0 WF_qubit1 WF_braqubit0 WF_braqubit1 : wf_db.
#[export] Hint Resolve WF_bool_to_ket WF_bool_to_matrix WF_bool_to_matrix' : wf_db.
#[export] Hint Resolve WF_ket WF_bra WF_bools_to_matrix : wf_db.
#[export] Hint Resolve WF_xbasis_plus WF_xbasis_minus WF_ybasis_plus WF_ybasis_minus : wf_db.

(* EPRpair is well formed; registered as a wf_db hint. *)
Lemma WF_EPRpair : WF_Matrix ∣Φ+⟩.

#[export] Hint Resolve WF_EPRpair : wf_db.

(* Well-formedness of the gate matrices, registered as wf_db hints. *)
Lemma WF_hadamard : WF_Matrix hadamard.
Lemma WF_σx : WF_Matrix σx.
Lemma WF_σy : WF_Matrix σy.
Lemma WF_σz : WF_Matrix σz.
Lemma WF_cnot : WF_Matrix cnot.
Lemma WF_notc : WF_Matrix notc.
Lemma WF_swap : WF_Matrix swap.

Lemma WF_rotation : ∀ θ ϕ λ, WF_Matrix (rotation θ ϕ λ).
Lemma WF_phase : ∀ ϕ, WF_Matrix (phase_shift ϕ).

Lemma WF_Sgate : WF_Matrix Sgate.
Lemma WF_Tgate: WF_Matrix Tgate.

(* control preserves well-formedness. *)
Lemma WF_control : ∀ (n : nat) (U : Matrix n n),
      WF_Matrix U → WF_Matrix (control U).

#[export] Hint Resolve WF_hadamard WF_σx WF_σy WF_σz WF_cnot WF_notc WF_swap : wf_db.
#[export] Hint Resolve WF_phase WF_Sgate WF_Tgate WF_rotation : wf_db.

#[export] Hint Extern 2 (WF_Matrix (phase_shift _)) ⇒ apply WF_phase : wf_db.
#[export] Hint Extern 2 (WF_Matrix (control _)) ⇒ apply WF_control : wf_db.

(* A direct sum of two well-formed matrices splits into Kronecker products
   with the two basis projectors. The binders o and p do not occur in the
   statement. *)
Lemma direct_sum_decomp : ∀ (m n o p : nat) (A B : Matrix m n),
  WF_Matrix A → WF_Matrix B →
  A .⊕ B = ∣0⟩⟨0∣ ⊗ A .+ ∣1⟩⟨1∣ ⊗ B.

Unitaries are unitary


(* A square matrix is a well-formed unitary when it is well formed and its
   adjoint is a left inverse. Unfoldable through the U_db hint database. *)
Definition WF_Unitary {n: nat} (U : Matrix n n): Prop :=
  WF_Matrix U ∧ U † × U = I n.

#[export] Hint Unfold WF_Unitary : U_db.


(* Unitarity of the basic gates, registered in the unit_db hint database. *)
Lemma H_unitary : WF_Unitary hadamard.

Lemma σx_unitary : WF_Unitary σx.

Lemma σy_unitary : WF_Unitary σy.

Lemma σz_unitary : WF_Unitary σz.

Lemma phase_unitary : ∀ ϕ, @WF_Unitary 2 (phase_shift ϕ).

Lemma S_unitary : WF_Unitary Sgate.

Lemma T_unitary : WF_Unitary Tgate.

Lemma rotation_unitary : ∀ θ ϕ λ, @WF_Unitary 2 (rotation θ ϕ λ).

Lemma x_rotation_unitary : ∀ θ, @WF_Unitary 2 (x_rotation θ).

Lemma y_rotation_unitary : ∀ θ, @WF_Unitary 2 (y_rotation θ).

(* control preserves unitarity. *)
Lemma control_unitary : ∀ n (A : Matrix n n),
                          WF_Unitary A → WF_Unitary (control A).

#[export] Hint Resolve H_unitary S_unitary T_unitary σx_unitary σy_unitary σz_unitary : unit_db.
#[export] Hint Resolve phase_unitary rotation_unitary x_rotation_unitary y_rotation_unitary control_unitary : unit_db.


(* Closure properties of WF_Unitary (adjoint, Kronecker product, big
   Kronecker product, matrix product, unit-modulus scaling) and unitarity of
   the two-qubit gates. Registered in the unit_db hint database.
   NOTE(review): the symbol after A in transpose_unitary was lost in this
   listing; the adjoint is used here to match the definition of WF_Unitary.
   Confirm against the proof script. *)
Lemma transpose_unitary : ∀ n (A : Matrix n n), WF_Unitary A → WF_Unitary (A†).

Lemma cnot_unitary : WF_Unitary cnot.

Lemma notc_unitary : WF_Unitary notc.

Lemma id_unitary : ∀ n, WF_Unitary (I n).

Lemma swap_unitary : WF_Unitary swap.

Lemma zero_not_unitary : ∀ n, ¬ (WF_Unitary (@Zero (2^n) (2^n))).

Lemma kron_unitary : ∀ {m n} (A : Matrix m m) (B : Matrix n n),
  WF_Unitary A → WF_Unitary B → WF_Unitary (A ⊗ B).

(* NOTE(review): ⨂ is assumed to be the prefix notation for big_kron. *)
Lemma big_kron_unitary : ∀ (n : nat) (ls : list (Square n)),
  (∀ a, In a ls → WF_Unitary a) → WF_Unitary (⨂ ls).

Lemma big_kron_unitary' : ∀ (n m : nat) (ls : list (Square n)),
  length ls = m → (∀ a, In a ls → WF_Unitary a) → @WF_Unitary (n^m) (⨂ ls).

Lemma Mmult_unitary : ∀ (n : nat) (A : Square n) (B : Square n),
  WF_Unitary A →
  WF_Unitary B →
  WF_Unitary (A × B).

Lemma scale_unitary : ∀ (n : nat) (c : C) (A : Square n),
  WF_Unitary A →
  (c × c ^*)%C = C1 →
  WF_Unitary (c .* A).

#[export] Hint Resolve transpose_unitary cnot_unitary notc_unitary id_unitary : unit_db.
#[export] Hint Resolve swap_unitary zero_not_unitary kron_unitary big_kron_unitary big_kron_unitary' Mmult_unitary scale_unitary : unit_db.

(* hadamard equals its own transpose, and adjoint and transpose commute.
   NOTE(review): the postfix operators were lost in this listing and are
   restored from the lemma names; confirm against the proof script. *)
Lemma hadamard_st : hadamard ⊤ = hadamard.

Lemma adjoint_transpose_comm : ∀ m n (A : Matrix m n),
  A † ⊤ = A ⊤ †.



(* Alias for id_adjoint_eq, named like the other _sa lemmas in this file. *)
Definition id_sa := id_adjoint_eq.

(* Self-adjointness of the basic gates and projectors, and the adjoints of
   control, phase_shift and rotation. Added to the Q_db rewrite database. *)
Lemma hadamard_sa : hadamard† = hadamard.

Lemma σx_sa : σx† = σx.

Lemma σy_sa : σy† = σy.

Lemma σz_sa : σz† = σz.

Lemma cnot_sa : cnot† = cnot.

Lemma swap_sa : swap† = swap.

Lemma control_adjoint : ∀ n (U : Square n), (control U)† = control (U†).

Lemma control_sa : ∀ (n : nat) (A : Square n),
    A† = A → (control A)† = (control A).

Lemma phase_adjoint : ∀ ϕ, (phase_shift ϕ)† = phase_shift (-ϕ).


Lemma rotation_adjoint : ∀ θ ϕ λ, (rotation θ ϕ λ)† = rotation (-θ) (-λ) (-ϕ).

Lemma braqubit0_sa : ∣0⟩⟨0∣† = ∣0⟩⟨0∣.
Lemma braqubit1_sa : ∣1⟩⟨1∣† = ∣1⟩⟨1∣.

Hint Rewrite hadamard_sa σx_sa σy_sa σz_sa cnot_sa swap_sa braqubit1_sa braqubit0_sa control_adjoint phase_adjoint rotation_adjoint : Q_db.


(* cnot and notc as sums of Kronecker products of a basis projector with
   σx or the identity. *)
Lemma cnot_decomposition : ∣1⟩⟨1∣ ⊗ σx .+ ∣0⟩⟨0∣ ⊗ I 2 = cnot.

Lemma notc_decomposition : σx ⊗ ∣1⟩⟨1∣ .+ I 2 ⊗ ∣0⟩⟨0∣ = notc.

Phase Lemmas


(* Special values of phase_shift, additivity of the angle under
   multiplication, and periodicity in the angle. *)
Lemma phase_0 : phase_shift 0 = I 2.

Lemma phase_2pi : phase_shift (2 × PI) = I 2.

Lemma phase_pi : phase_shift PI = σz.

Lemma phase_neg_pi : phase_shift (-PI) = σz.

Lemma phase_mul : ∀ θ θ', phase_shift θ × phase_shift θ' = phase_shift (θ + θ').

Lemma phase_PI4_m8 : ∀ k,
  phase_shift (IZR k × PI / 4) = phase_shift (IZR (k - 8) × PI / 4).

Lemma phase_mod_2PI : ∀ k, phase_shift (IZR k × PI) = phase_shift (IZR (k mod 2) × PI).

Lemma phase_mod_2PI_scaled : ∀ (k sc : Z),
  sc ≠ 0%Z →
  phase_shift (IZR k × PI / IZR sc) = phase_shift (IZR (k mod (2 × sc)) × PI / IZR sc).

Hint Rewrite phase_0 phase_2pi phase_pi phase_neg_pi : Q_db.


(* A is positive semidefinite when, for every well-formed vector z, the real
   part of the single entry of z† × A × z is non-negative. *)
Definition positive_semidefinite {n} (A : Square n) : Prop :=
  ∀ (z : Vector n), WF_Matrix z → fst ((z† × A × z) O O) ≥ 0.

(* Outer products of a well-formed vector with itself are positive
   semidefinite, with three concrete instances. *)
Lemma pure_psd : ∀ (n : nat) (ϕ : Vector n), (WF_Matrix ϕ) → positive_semidefinite (ϕ × ϕ†).

Lemma braket0_psd : positive_semidefinite ∣0⟩⟨0∣.

Lemma braket1_psd : positive_semidefinite ∣1⟩⟨1∣.

Lemma H0_psd : positive_semidefinite (hadamard × ∣0⟩⟨0∣ × hadamard).


(* Parsing-only synonym: a density matrix of dimension n is just an n x n matrix. *)
Notation Density n := (Matrix n n) (only parsing).

(* A density matrix is classical when every off-diagonal entry is 0. *)
Definition Classical {n} (ρ : Density n) := ∀ i j, i ≠ j → ρ i j = 0.

(* A pure state vector is well formed and has inner product I 1 with itself. *)
Definition Pure_State_Vector {n} (φ : Vector n): Prop :=
  WF_Matrix φ ∧ φ† × φ = I 1.

(* A pure state is the outer product of some pure state vector with itself. *)
Definition Pure_State {n} (ρ : Density n) : Prop :=
  ∃ φ, Pure_State_Vector φ ∧ ρ = φ × φ†.

(* Mixed states: every pure state, plus every convex combination of two
   mixed states with a weight p strictly between 0 and 1. *)
Inductive Mixed_State {n} : Matrix n n → Prop :=
| Pure_S : ∀ ρ, Pure_State ρ → Mixed_State ρ
| Mix_S : ∀ (p : R) ρ1 ρ2, 0 < p < 1 → Mixed_State ρ1 → Mixed_State ρ2 →
                                       Mixed_State (p .* ρ1 .+ (1-p)%R .* ρ2).

(* Pure and mixed states are well-formed matrices; registered as wf_db hints. *)
Lemma WF_Pure : ∀ {n} (ρ : Density n), Pure_State ρ → WF_Matrix ρ.
#[export] Hint Resolve WF_Pure : wf_db.

Lemma WF_Mixed : ∀ {n} (ρ : Density n), Mixed_State ρ → WF_Matrix ρ.
#[export] Hint Resolve WF_Mixed : wf_db.

(* Concrete pure states, and the fact that I 1 is the only pure state of dimension 1. *)
Lemma pure0 : Pure_State ∣0⟩⟨0∣.

Lemma pure1 : Pure_State ∣1⟩⟨1∣.

Lemma pure_id1 : Pure_State (I 1).

Lemma pure_dim1 : ∀ (ρ : Square 1), Pure_State ρ → ρ = I 1.

(* Pure and mixed states are preserved by unitaries and Kronecker products,
   and have trace 1. *)
Lemma pure_state_unitary_pres : ∀ {n} (ϕ : Vector n) (U : Square n),
  Pure_State_Vector ϕ → WF_Unitary U → Pure_State_Vector (U × ϕ).

Lemma pure_state_vector_kron : ∀ {n m} (ϕ : Vector n) (ψ : Vector m),
  Pure_State_Vector ϕ → Pure_State_Vector ψ → Pure_State_Vector (ϕ ⊗ ψ).

Lemma pure_state_kron : ∀ m n (ρ : Square m) (φ : Square n),
  Pure_State ρ → Pure_State φ → Pure_State (ρ ⊗ φ).

Lemma mixed_state_kron : ∀ m n (ρ : Square m) (φ : Square n),
  Mixed_State ρ → Mixed_State φ → Mixed_State (ρ ⊗ φ).

Lemma pure_state_trace_1 : ∀ {n} (ρ : Density n), Pure_State ρ → trace ρ = 1.

Lemma mixed_state_trace_1 : ∀ {n} (ρ : Density n), Mixed_State ρ → trace ρ = 1.


(* Diagonal entries of a mixed state are real and lie between 0 and 1;
   I 1 is the only mixed state of dimension 1. *)
Lemma mixed_state_diag_in01 : ∀ {n} (ρ : Density n) i , Mixed_State ρ →
                                                        0 ≤ fst (ρ i i) ≤ 1.

Lemma mixed_state_diag_real : ∀ {n} (ρ : Density n) i , Mixed_State ρ →
                                                        snd (ρ i i) = 0.

Lemma mixed_dim1 : ∀ (ρ : Square 1), Mixed_State ρ → ρ = I 1.


(* Norm of a vector: square root of the real part of its inner product with itself. *)
Definition norm {n} (ψ : Vector n) : R :=
  sqrt (fst ((ψ† × ψ) O O)).

(* The inner product of a vector with itself has zero imaginary part. *)
Lemma norm_real : ∀ {n} (v : Vector n), snd ((v† × v) 0%nat 0%nat) = 0%R.

(* Scale ψ by the reciprocal of its norm. normalized_norm_1 requires a
   nonzero norm for the result to have norm 1. *)
Definition normalize {n} (ψ : Vector n) :=
  / (norm ψ) .* ψ.

(* Basic facts about inner products and [norm]. *)
Lemma inner_product_ge_0 : ∀ {d} (ψ : Vector d),
  0 ≤ fst ((ψ† × ψ) O O).

Lemma norm_scale : ∀ {n} c (v : Vector n), norm (c .* v) = ((Cmod c) × norm v)%R.

Lemma div_real : ∀ (c : C),
  snd c = 0 → snd (/ c) = 0.

Lemma Cmod_real : ∀ (c : C),
  fst c ≥ 0 → snd c = 0 → Cmod c = fst c.

Lemma normalized_norm_1 : ∀ {n} (v : Vector n),
  norm v ≠ 0 → norm (normalize v) = 1.

Lemma rewrite_norm : ∀ {d} (ψ : Vector d),
    fst (((ψ)† × ψ) O O) = big_sum (fun i ⇒ Cmod (ψ i O) ^ 2)%R d.

Lemma norm_zero_iff_zero : ∀ {n} (v : Vector n),
  WF_Matrix v → (norm v = 0%R ↔ v = Zero).

Density matrices and superoperators

(* A superoperator is a function from density matrices of dimension m to
   density matrices of dimension n. *)
Definition Superoperator m n := Density m → Density n.

(* A superoperator is well formed when it maps mixed states to mixed states. *)
Definition WF_Superoperator {m n} (f : Superoperator m n) :=
  (∀ ρ, Mixed_State ρ → Mixed_State (f ρ)).

(* Superoperator induced by a matrix M: conjugate ρ by M, that is M × ρ × M†. *)
Definition super {m n} (M : Matrix m n) : Superoperator n m := fun ρ ⇒
  M × ρ × M†.

(* super of the identity is the identity on well-formed matrices, and super
   preserves well-formedness (registered as a wf_db hint). *)
Lemma super_I : ∀ n ρ,
      WF_Matrix ρ →
      super (I n) ρ = ρ.

Lemma WF_super : ∀ m n (U : Matrix m n) (ρ : Square n),
  WF_Matrix U → WF_Matrix ρ → WF_Matrix (super U ρ).

#[export] Hint Resolve WF_super : wf_db.

(* Applying super U to an outer product is the outer product of the transformed vector. *)
Lemma super_outer_product : ∀ m (φ : Matrix m 1) (U : Matrix m m),
    super U (outer_product φ φ) = outer_product (U × φ) (U × φ).

(* Composition of superoperators: apply f first, then g. *)
Definition compose_super {m n p} (g : Superoperator n p) (f : Superoperator m n)
                      : Superoperator m p := fun ρ ⇒ g (f ρ).

(* If f and g both preserve well-formedness, so does their composition;
   registered as a wf_db hint. *)
Lemma WF_compose_super : ∀ m n p (g : Superoperator n p) (f : Superoperator m n)
  (ρ : Square m),
  WF_Matrix ρ →
  (∀ A, WF_Matrix A → WF_Matrix (f A)) →
  (∀ A, WF_Matrix A → WF_Matrix (g A)) →
  WF_Matrix (compose_super g f ρ).

#[export] Hint Resolve WF_compose_super : wf_db.

(* Composing two well-formed superoperators gives a well-formed superoperator. *)
Lemma compose_super_correct : ∀ {m n p}
                              (g : Superoperator n p) (f : Superoperator m n),
      WF_Superoperator g →
      WF_Superoperator f →
      WF_Superoperator (compose_super g f).

(* Equal-weight mixture of two superoperators. The second weight is written
   as (1 - 1/2) rather than 1/2, which matches the shape of the Mix_S
   constructor of Mixed_State. *)
Definition sum_super {m n} (f g : Superoperator m n) : Superoperator m n :=
  fun ρ ⇒ (1/2)%R .* f ρ .+ (1 - 1/2)%R .* g ρ.

(* The equal-weight mixture of two well-formed superoperators is well formed. *)
Lemma sum_super_correct : ∀ m n (f g : Superoperator m n),
      WF_Superoperator f → WF_Superoperator g → WF_Superoperator (sum_super f g).

(* SZero maps every input to Zero. Splus adds the results of two
   superoperators pointwise, without any weighting. *)
Definition SZero {m n} : Superoperator m n := fun ρ ⇒ Zero.
Definition Splus {m n} (S T : Superoperator m n) : Superoperator m n :=
  fun ρ ⇒ S ρ .+ T ρ.

(* Basic superoperators built from the basis states:
   - new0_op / new1_op: super of a basis ket (dimension 1 to 2);
   - meas_op: sum of super over the two basis projectors;
   - discard_op: sum of super over the two basis bras (dimension 2 to 1). *)
Definition new0_op : Superoperator 1 2 := super ∣0⟩.
Definition new1_op : Superoperator 1 2 := super ∣1⟩.
Definition meas_op : Superoperator 2 2 := Splus (super ∣0⟩⟨0∣) (super ∣1⟩⟨1∣).
Definition discard_op : Superoperator 2 1 := Splus (super ⟨0∣) (super ⟨1∣).

(* super of a unitary preserves pure and mixed states, hence is a well-formed
   superoperator; compose_super is associative and agrees with matrix product. *)
Lemma pure_unitary : ∀ {n} (U ρ : Matrix n n),
  WF_Unitary U → Pure_State ρ → Pure_State (super U ρ).

Lemma mixed_unitary : ∀ {n} (U ρ : Matrix n n),
  WF_Unitary U → Mixed_State ρ → Mixed_State (super U ρ).

Lemma super_unitary_correct : ∀ {n} (U : Matrix n n),
  WF_Unitary U → WF_Superoperator (super U).

Lemma compose_super_assoc : ∀ {m n p q}
      (f : Superoperator m n) (g : Superoperator n p) (h : Superoperator p q),
      compose_super (compose_super f g) h
    = compose_super f (compose_super g h).

Lemma compose_super_eq : ∀ {m n p} (A : Matrix m n) (B : Matrix n p),
      compose_super (super A) (super B) = super (A × B).


(* Simplification tactic: try restore_dims, then rewrite with the
   M_db_light, M_db and Q_db databases. *)
Ltac Qsimpl := try restore_dims; autorewrite with M_db_light M_db Q_db.


(* swap exchanges the two factors of a Kronecker product of well-formed
   qubit vectors; added to the Q_db rewrite database. *)
Lemma swap_spec : ∀ (q q' : Vector 2), WF_Matrix q →
                                       WF_Matrix q' →
                                       swap × (q ⊗ q') = q' ⊗ q.

Hint Rewrite swap_spec using (auto 100 with wf_db) : Q_db.

(* On four qubits, swap_to_0 4 2 exchanges positions 0 and 2 and leaves the rest in place. *)
Example swap_to_0_test_24 : ∀ (q0 q1 q2 q3 : Vector 2),
  WF_Matrix q0 → WF_Matrix q1 → WF_Matrix q2 → WF_Matrix q3 →
  swap_to_0 4 2 × (q0 ⊗ q1 ⊗ q2 ⊗ q3) = (q2 ⊗ q1 ⊗ q0 ⊗ q3).

(* Small instances of swap_two written out in terms of swap and identities. *)
Lemma swap_two_base : swap_two 2 1 0 = swap.

Lemma swap_second_two : swap_two 3 1 2 = I 2 ⊗ swap.

Lemma swap_0_2 : swap_two 3 0 2 = (I 2 ⊗ swap) × (swap ⊗ I 2) × (I 2 ⊗ swap).


(* On four qubits, move_to_0 4 2 brings qubit 2 to the front and keeps the
   relative order of the others. *)
Example move_to_0_test_24 : ∀ (q0 q1 q2 q3 : Vector 2),
  WF_Matrix q0 → WF_Matrix q1 → WF_Matrix q2 → WF_Matrix q3 →
  move_to_0 4 2 × (q0 ⊗ q1 ⊗ q2 ⊗ q3) = (q2 ⊗ q0 ⊗ q1 ⊗ q3).