(** * QuantumLib.VectorStates *)

Require Export Pad.



(** The states ∣ + ⟩ = (∣ 0 ⟩ + ∣ 1 ⟩)/√2 and ∣ - ⟩ = (∣ 0 ⟩ - ∣ 1 ⟩)/√2.
    The amplitudes must be 1/√2 (not 1/2): only then are these unit vectors
    and the images of ∣ 0 ⟩ and ∣ 1 ⟩ under [hadamard]. *)
Notation "∣ + ⟩" := (/√2 .* ∣ 0 ⟩ .+ /√2 .* ∣ 1 ⟩).
Notation "∣ - ⟩" := (/√2 .* ∣ 0 ⟩ .+ (-/√2) .* ∣ 1 ⟩).


(** The named constants [⟨0∣], [⟨1∣], [∣0⟩], [∣1⟩] coincide with [bra] / [ket]
    applied to 0 and 1. Used left-to-right in [ket_db] to normalize into
    the [bra n] / [ket n] form. *)
Lemma bra0_equiv : ⟨0∣ = bra 0.

Lemma bra1_equiv : ⟨1∣ = bra 1.

Lemma ket0_equiv : ∣0⟩ = ket 0.

Lemma ket1_equiv : ∣1⟩ = ket 1.

(** Inner products of the basis bras and kets: equal indices give the 1x1
    identity [I 1], distinct indices give [Zero]. *)
Lemma bra0ket0 : bra 0 × ket 0 = I 1.

Lemma bra0ket1 : bra 0 × ket 1 = Zero.

Lemma bra1ket0 : bra 1 × ket 0 = Zero.

Lemma bra1ket1 : bra 1 × ket 1 = I 1.

(** Hadamard exchanges the computational basis {∣ 0 ⟩, ∣ 1 ⟩} and the
    basis {∣ + ⟩, ∣ - ⟩} in both directions. *)
Lemma H0_spec : hadamard × ∣ 0 ⟩ = ∣ + ⟩.

Lemma H1_spec : hadamard × ∣ 1 ⟩ = ∣ - ⟩.

Lemma Hplus_spec : hadamard × ∣ + ⟩ = ∣ 0 ⟩.

Lemma Hminus_spec : hadamard × ∣ - ⟩ = ∣ 1 ⟩.

Local Open Scope nat_scope.

(** Applying [n ⨂ hadamard] to [n ⨂ ∣0⟩] yields [n ⨂ ∣+⟩]. *)
Lemma H0_kron_n_spec : forall n,
  n ⨂ hadamard × n ⨂ ∣0⟩ = n ⨂ ∣+⟩.

Local Close Scope nat_scope.

(** Action of the Pauli matrices on the basis kets ∣ 0 ⟩ and ∣ 1 ⟩. *)
Lemma X0_spec : σx × ∣ 0 ⟩ = ∣ 1 ⟩.

Lemma X1_spec : σx × ∣ 1 ⟩ = ∣ 0 ⟩.

Lemma Y0_spec : σy × ∣ 0 ⟩ = Ci .* ∣ 1 ⟩.

Lemma Y1_spec : σy × ∣ 1 ⟩ = -Ci .* ∣ 0 ⟩.

Lemma Z0_spec : σz × ∣ 0 ⟩ = ∣ 0 ⟩.

Lemma Z1_spec : σz × ∣ 1 ⟩ = -1 .* ∣ 1 ⟩.

(** [phase_shift ϕ] fixes [ket 0] and scales [ket 1] by [Cexp ϕ]. *)
Lemma phase0_spec : forall ϕ, phase_shift ϕ × ket 0 = ket 0.

Lemma phase1_spec : forall ϕ, phase_shift ϕ × ket 1 = Cexp ϕ .* ket 1.

(** Interpret a boolean as the real number 1 ([true]) or 0 ([false]). *)
Definition b2R (b : bool) : R := if b then 1%R else 0%R.

(* NOTE(review): the statements below use a bool as a real ([b * θ]) and as
   a nat ([∣ b ⟩], [(-1)^b]). They rely on coercions (presumably [b2R] and
   [Nat.b2n]) that are not visible in this listing; confirm they are declared. *)

(** A phase shift scales the basis ket [∣ b ⟩] by [Cexp (b * θ)]. *)
Lemma phase_shift_on_ket : forall (θ : R) (b : bool),
  phase_shift θ × ∣ b ⟩ = (Cexp (b * θ)) .* ∣ b ⟩.

(** Hadamard on a basis ket gives [(∣ 0 ⟩ + (-1)^b ∣ 1 ⟩) / √2]. *)
Lemma hadamard_on_ket : forall (b : bool),
  hadamard × ∣ b ⟩ = /√2 .* (∣ 0 ⟩ .+ (-1)^b .* ∣ 1 ⟩).


(** CNOT on a two-qubit basis ket: the target becomes [(x + y) mod 2]. *)
Lemma CNOT_spec : forall (x y : nat),
  (x < 2)%nat -> (y < 2)%nat ->
  cnot × ∣ x,y ⟩ = ∣ x, (x + y) mod 2 ⟩.

Lemma CNOT00_spec : cnot × ∣ 0,0 ⟩ = ∣ 0,0 ⟩.

Lemma CNOT01_spec : cnot × ∣ 0,1 ⟩ = ∣ 0,1 ⟩.

Lemma CNOT10_spec : cnot × ∣ 1,0 ⟩ = ∣ 1,1 ⟩.

Lemma CNOT11_spec : cnot × ∣ 1,1 ⟩ = ∣ 1,0 ⟩.

(** SWAP exchanges the two qubits of a basis ket. *)
Lemma SWAP_spec : forall x y, swap × ∣ x,y ⟩ = ∣ y,x ⟩.


(** Rewrite database [ket_db] for simplifying expressions built from basis
    bras and kets. Each lemma is applied left-to-right by [autorewrite]. *)
(* Normalize the named constants to [bra n] / [ket n], then evaluate inner products. *)
Hint Rewrite bra0_equiv bra1_equiv ket0_equiv ket1_equiv : ket_db.
Hint Rewrite bra0ket0 bra0ket1 bra1ket0 bra1ket1 : ket_db.
(* Distribution lemmas over [.+] and scalar-movement lemmas. *)
Hint Rewrite Mmult_plus_distr_l Mmult_plus_distr_r kron_plus_distr_l kron_plus_distr_r Mscale_plus_distr_r : ket_db.
Hint Rewrite Mscale_mult_dist_l Mscale_mult_dist_r Mscale_kron_dist_l Mscale_kron_dist_r : ket_db.
Hint Rewrite Mscale_assoc @Mmult_assoc : ket_db.
(* Identity/zero elimination; side conditions are discharged by [auto with wf_db]. *)
Hint Rewrite Mmult_1_l Mmult_1_r kron_1_l kron_1_r Mscale_0_l Mscale_0_r Mscale_1_l Mplus_0_l Mplus_0_r using (auto with wf_db) : ket_db.
Hint Rewrite kron_0_l kron_0_r Mmult_0_l Mmult_0_r : ket_db.
Hint Rewrite @kron_mixed_product : ket_db.

(* Gate specifications on basis kets stated earlier in this file. *)
Hint Rewrite H0_spec H1_spec Hplus_spec Hminus_spec X0_spec X1_spec Y0_spec Y1_spec
     Z0_spec Z1_spec phase0_spec phase1_spec : ket_db.
Hint Rewrite CNOT00_spec CNOT01_spec CNOT10_spec CNOT11_spec SWAP_spec : ket_db.

(** The adjoint of [ket n] is [bra n]. *)
Lemma ket2bra : forall n, (ket n)† = bra n.
Hint Rewrite ket2bra : ket_db.

(** Transposes of the basis bras/kets at 0 and 1. *)
Lemma ket0_transpose_bra0 : transpose (ket 0) = bra 0.

Lemma ket1_transpose_bra1 : transpose (ket 1) = bra 1.

Lemma bra0_transpose_ket0 : transpose (bra 0) = ket 0.

Lemma bra1_transpose_ket1 : transpose (bra 1) = ket 1.

(** Adjoints of the basis bras/kets at 0 and 1. *)
Lemma bra1_adjoint_ket1 : (bra 1)† = ket 1.

Lemma ket1_adjoint_bra1 : (ket 1)† = bra 1.

Lemma bra0_adjoint_ket0 : (bra 0)† = ket 0.

Lemma ket0_adjoint_bra0 : (ket 0)† = bra 0.

(** The product [-i · σx · σy · σz] acts as the identity on ∣ 0 ⟩ and ∣ 1 ⟩. *)
Lemma XYZ0 : -Ci .* σx × σy × σz × ∣ 0 ⟩ = ∣ 0 ⟩.

Lemma XYZ1 : -Ci .* σx × σy × σz × ∣ 1 ⟩ = ∣ 1 ⟩.

(** * Classical States *)

Local Close Scope C_scope.
Local Close Scope R_scope.
Local Open Scope nat_scope.

(** Functional update: [update f i x] agrees with [f] everywhere except at
    index [i], where it returns [x]. *)
Definition update {A} (f : nat -> A) (i : nat) (x : A) :=
  fun j => if j =? i then x else f j.

(** Reading back an updated index, or any other index. *)
Lemma update_index_eq : forall {A} (f : nat -> A) i b, (update f i b) i = b.

Lemma update_index_neq : forall {A} (f : nat -> A) i j b,
  i <> j -> (update f i b) j = f j.

(** Updating an index with its current value is a no-op. *)
Lemma update_same : forall {A} (f : nat -> A) i b,
  b = f i -> update f i b = f.

(** A later update of the same index overrides an earlier one. *)
Lemma update_twice_eq : forall {A} (f : nat -> A) i b b',
  update (update f i b) i b' = update f i b'.

(** Updates at distinct indices commute. *)
Lemma update_twice_neq : forall {A} (f : nat -> A) i j b b',
  i <> j -> update (update f i b) j b' = update (update f j b') i b.

(** [shift f k] drops the first [k] entries of [f]: index [i] reads [f (i + k)]. *)
Definition shift {A} (f : nat -> A) k := fun i => f (i + k).

Lemma shift_0 : forall {A} (f : nat -> A), shift f 0 = f.

(** Consecutive shifts compose additively. *)
Lemma shift_plus : forall {A} (f : nat -> A) i j,
  shift (shift f j) i = shift f (i + j).

Lemma shift_simplify : forall {A} (f : nat -> A) i j,
  shift f i j = f (j + i).

(** [fswap f x y] exchanges the values of [f] at indices [x] and [y]. *)
Definition fswap {A} (f : nat -> A) x y :=
  fun i => if i =? x then f y else if i =? y then f x else f i.

Lemma fswap_simpl1 : forall A f x y, @fswap A f x y x = f y.

Lemma fswap_simpl2 : forall A f x y, @fswap A f x y y = f x.

Lemma fswap_same : forall A f x, @fswap A f x x = f.

Lemma fswap_neq : forall {A} (f : nat -> A) a b x,
  a <> x -> b <> x -> fswap f a b x = f x.

(** [fswap] expressed as two successive [update]s. *)
Lemma fswap_rewrite : forall {A} (f : nat -> A) a b,
  fswap f a b = update (update f b (f a)) a (f b).

(** [f_to_vec n f] is the n-qubit basis state [∣ f 0 ⟩ ⊗ ... ⊗ ∣ f (n-1) ⟩].
    Qubit 0 is the leftmost tensor factor. *)
Fixpoint f_to_vec (n : nat) (f : nat -> bool) : Vector (2^n) :=
  match n with
  | 0 => I 1
  | S n' => (f_to_vec n' f) ⊗ ∣ f n' ⟩
  end.

Lemma f_to_vec_WF : forall (n : nat) (f : nat -> bool),
  WF_Matrix (f_to_vec n f).
#[export] Hint Resolve f_to_vec_WF : wf_db.

(** [f_to_vec n f] only depends on the first [n] values of [f]. *)
Lemma f_to_vec_eq : forall n f f',
  (forall i, i < n -> f i = f' i) ->
  f_to_vec n f = f_to_vec n f'.


(** [basis_vector n k] is the column vector with a single [C1] at row [k]
    and [C0] elsewhere. *)
Definition basis_vector (n k : nat) : Vector n :=
  fun i j => if (i =? k) && (j =? 0) then C1 else C0.

Lemma basis_vector_WF : forall n i, (i < n)%nat -> WF_Matrix (basis_vector n i).
#[export] Hint Resolve basis_vector_WF : wf_db.

Lemma basis_vector_product_eq : forall d n,
  n < d -> (basis_vector d n)† × basis_vector d n = I 1.

Lemma basis_vector_pure_state : forall n i,
  (i < n)%nat -> Pure_State_Vector (basis_vector n i).

(** Distinct in-range basis vectors are orthogonal. *)
Lemma basis_vector_product_neq : forall d m n,
  (m < d)%nat -> (n < d)%nat -> (m <> n)%nat ->
  (basis_vector d m)† × basis_vector d n = Zero.

(** Multiplying by the [j]-th basis vector extracts column [j]. *)
Lemma matrix_times_basis_eq : forall m n (A : Matrix m n) i j,
  WF_Matrix A ->
  (A × basis_vector n j) i 0 = A i j.

(** Two well-formed matrices that agree on every basis vector are equal. *)
Lemma equal_on_basis_vectors_implies_equal : forall m n (A B : Matrix m n),
  WF_Matrix A ->
  WF_Matrix B ->
  (forall k, k < n -> A × (basis_vector n k) = B × (basis_vector n k)) ->
  A = B.

(** Characterization of quotient and remainder by a positive divisor [r]. *)
Lemma divmod_decomp : forall x y z r,
    (r > 0)%nat ->
    (z < r)%nat ->
    (x = y * r + z <-> x / r = y /\ x mod r = z)%nat.

(** A basis vector of a product space splits as a tensor product of basis
    vectors, with [x] as the high-order and [y] as the low-order index. *)
Lemma split_basis_vector : forall m n x y,
  (x < 2 ^ m)%nat ->
  (y < 2 ^ n)%nat ->
  basis_vector (2 ^ (m + n)) (x * 2 ^ n + y)
    = basis_vector (2 ^ m) x ⊗ basis_vector (2 ^ n) y.


(** Decodes a little-endian bit list: the head is the least-significant bit.
    NOTE(review): [b + ...] relies on a bool-to-nat coercion declared elsewhere. *)
Fixpoint binlist_to_nat (l : list bool) : nat :=
  match l with
  | [] => 0
  | b :: l' => b + 2 * binlist_to_nat l'
  end.

(** [funbool_to_list len f] is [[f (len - 1); ...; f 0]]. *)
Fixpoint funbool_to_list (len : nat) (f : nat -> bool) :=
  match len with
  | O => []
  | S len' => f len' :: funbool_to_list len' f
  end.

(** Reads [f 0], ..., [f (len - 1)] as a binary number with [f 0] as the
    most-significant bit. *)
Definition funbool_to_nat (len : nat) (f : nat -> bool) :=
  binlist_to_nat (funbool_to_list len f).

(** [funbool_to_nat n f] is an [n]-bit number. *)
Lemma funbool_to_nat_bound : forall n f, (funbool_to_nat n f < 2 ^ n)%nat.

(** Only the first [n] values of [f] matter. *)
Lemma funbool_to_nat_eq : forall n f f',
  (forall x, x < n -> f x = f' x)%nat ->
  funbool_to_nat n f = funbool_to_nat n f'.

Local Opaque Nat.mul.
(** Splits the number into its top [k] bits and remaining [n - k] bits. *)
Lemma funbool_to_nat_shift : forall n f k, (k < n)%nat ->
  funbool_to_nat n f = (2 ^ (n - k) * funbool_to_nat k f + funbool_to_nat (n - k) (shift f k))%nat.
Local Transparent Nat.mul.

(** [f_to_vec n f] is the basis vector indexed by [funbool_to_nat n f]. *)
Lemma basis_f_to_vec : forall n f,
  f_to_vec n f = basis_vector (2^n) (funbool_to_nat n f).

(** Increments a little-endian bit list by one; carries may extend the list. *)
Fixpoint incr_bin (l : list bool) :=
  match l with
  | [] => [true]
  | false :: t => true :: t
  | true :: t => false :: (incr_bin t)
  end.

(** Little-endian encoding of [n], built by incrementing from [[]] [n] times
    (no padding). *)
Fixpoint nat_to_binlist' n :=
  match n with
  | O => []
  | S n' => incr_bin (nat_to_binlist' n')
  end.

(** [nat_to_binlist' n] right-padded with [false] up to length [len]. *)
Definition nat_to_binlist len n :=
  let l := nat_to_binlist' n in
  l ++ (repeat false (len - length l)).

(** Places the list head at index [len - 1], the next element at
    [len - 2], and so on; every other index maps to [false]. *)
Fixpoint list_to_funbool len (l : list bool) : nat -> bool :=
  match l with
  | [] => fun _ => false
  | h :: t => update (list_to_funbool (len - 1)%nat t) (len - 1) h
  end.

Definition nat_to_funbool len n : nat -> bool :=
  list_to_funbool len (nat_to_binlist len n).

Lemma binlist_to_nat_append : forall l1 l2,
  binlist_to_nat (l1 ++ l2) =
    (binlist_to_nat l1 + 2 ^ (length l1) * binlist_to_nat l2)%nat.

Lemma binlist_to_nat_false : forall n, binlist_to_nat (repeat false n) = O.

Lemma binlist_to_nat_true : forall n, binlist_to_nat (repeat true n) = 2^n - 1.

(** Padding with [false] does not change the decoded value. *)
Lemma nat_to_binlist_eq_nat_to_binlist' : forall len n,
  binlist_to_nat (nat_to_binlist len n) = binlist_to_nat (nat_to_binlist' n).

Lemma nat_to_binlist_inverse : forall len n,
  binlist_to_nat (nat_to_binlist len n) = n.

Lemma nat_to_binlist_corr : forall l n,
   nat_to_binlist' n = l ->
   binlist_to_nat l = n.

(** Incrementing an all-[true] list grows it by one bit; otherwise the
    length is preserved. *)
Lemma incr_bin_true_length : forall l,
  Forall (fun b => b = true) l ->
  length (incr_bin l) = S (length l).

Lemma incr_bin_false_length : forall l,
  Exists (fun b => b <> true) l ->
  length (incr_bin l) = length l.

Lemma all_true_repeat : forall l,
  Forall (fun b : bool => b = true) l ->
  l = repeat true (length l).

Lemma nat_to_binlist_length' : forall k n,
    n < 2 ^ k -> length (nat_to_binlist' n) <= k.

Lemma nat_to_binlist_length : forall len n,
  (n < 2 ^ len)%nat -> length (nat_to_binlist len n) = len.

(** Updating an index at or beyond [dim] does not affect the first [dim] bits. *)
Lemma funbool_to_list_update_oob : forall f dim b n, (dim <= n)%nat ->
  funbool_to_list dim (update f n b) = funbool_to_list dim f.

Lemma list_to_funbool_inverse : forall len l,
  length l = len ->
  funbool_to_list len (list_to_funbool len l) = l.

Lemma nat_to_funbool_inverse : forall len n,
  (n < 2 ^ len)%nat -> funbool_to_nat len (nat_to_funbool len n) = n.

Local Opaque Nat.mul.
(** Doubling prepends a [false] low bit (for positive [n]); doubling plus
    one prepends a [true] low bit. *)
Lemma nat_to_binlist'_even : forall n, (n > 0)%nat ->
  nat_to_binlist' (2 * n) = false :: nat_to_binlist' n.

Lemma nat_to_binlist'_odd : forall n,
  nat_to_binlist' (2 * n + 1) = true :: nat_to_binlist' n.

Lemma binlist_to_nat_inverse : forall l n i,
  list_to_funbool n (nat_to_binlist' (binlist_to_nat l)) i = list_to_funbool n l i.

Lemma list_to_funbool_repeat_false : forall n i,
  list_to_funbool n (repeat false n) i = false.

(** If the encoded number is 0, every in-range bit is [false]. *)
Lemma funbool_to_nat_0 : forall n f,
  funbool_to_nat n f = O -> forall i, (i < n)%nat -> f i = false.

Lemma funbool_to_nat_inverse : forall len f i, (i < len)%nat ->
  nat_to_funbool len (funbool_to_nat len f) i = f i.
Local Transparent Nat.mul.

(** Every in-range basis vector is some [f_to_vec] state. *)
Lemma basis_f_to_vec_alt : forall len n, (n < 2 ^ len)%nat ->
  basis_vector (2 ^ len) n = f_to_vec len (nat_to_funbool len n).

(** Two well-formed operators that agree on all classical states are equal. *)
Lemma equal_on_basis_states_implies_equal : forall {dim} (A B : Square (2 ^ dim)),
  WF_Matrix A ->
  WF_Matrix B ->
  (forall f, A × (f_to_vec dim f) = B × (f_to_vec dim f)) ->
  A = B.

(** Updates outside the encoded range leave the state unchanged. *)
Lemma f_to_vec_update_oob : forall (n : nat) (f : nat -> bool) (i : nat) (b : bool),
  n <= i -> f_to_vec n (update f i b) = f_to_vec n f.

Lemma f_to_vec_shift_update_oob : forall (n : nat) (f : nat -> bool) (i j : nat) (b : bool),
  j + n <= i \/ i < j ->
  f_to_vec n (shift (update f i b) j) = f_to_vec n (shift f j).

(** Isolates qubit [i] as the middle factor of the state.
    NOTE(review): [base] does not occur in the statement; presumably kept
    for interface compatibility. *)
Lemma f_to_vec_split : forall (base n i : nat) (f : nat -> bool),
  i < n ->
  f_to_vec n f = (f_to_vec i f) ⊗ ∣ f i ⟩ ⊗ (f_to_vec (n - 1 - i) (shift f (i + 1))).

(** The tensor product of two classical states is one classical state. *)
Lemma f_to_vec_merge : forall f1 f2 m n,
  f_to_vec m f1 ⊗ f_to_vec n f2 =
    f_to_vec (m + n) (fun x => if x <? m then f1 x else f2 (x - m)%nat).


(** Effect of padded gates on classical basis states. *)

(** X on qubit [i] flips bit [i]. *)
Lemma f_to_vec_σx : forall (n i : nat) (f : nat -> bool),
  i < n ->
  (pad_u n i σx) × (f_to_vec n f) = f_to_vec n (update f i (negb (f i))).

(** CNOT with control [i] and target [j] xors bit [i] into bit [j]. *)
Lemma f_to_vec_cnot : forall (n i j : nat) (f : nat -> bool),
  i < n -> j < n -> i <> j ->
  (pad_ctrl n i j σx) × (f_to_vec n f) = f_to_vec n (update f j (xorb (f j) (f i))).

(** SWAP of qubits [i] and [j] exchanges the corresponding bits. *)
Lemma f_to_vec_swap : forall (n i j : nat) (f : nat -> bool),
  i < n -> j < n -> i <> j ->
  (pad_swap n i j) × (f_to_vec n f) = f_to_vec n (fswap f i j).

(** A phase shift on qubit [i] scales the state by [Cexp ((f i) * θ)]. *)
Lemma f_to_vec_phase_shift : forall (n i : nat) (θ : R) (f : nat -> bool),
  (i < n)%nat ->
  (pad_u n i (phase_shift θ)) × (f_to_vec n f) =
    (Cexp ((f i) * θ)) .* f_to_vec n f.

Local Open Scope R_scope.

(** Hadamard on qubit [i] yields the normalized superposition of bit [i] set
    to [false] and to [true], the latter with phase [Cexp ((f i) * PI)]. *)
Lemma f_to_vec_hadamard : forall (n i : nat) (f : nat -> bool),
  (i < n)%nat ->
  (pad_u n i hadamard) × (f_to_vec n f)
      = /√2 .* ((f_to_vec n (update f i false)) .+
                (Cexp ((f i) * PI)) .* f_to_vec n (update f i true)).

Local Close Scope R_scope.

(** Rewrite database [f_to_vec_db]: gate actions on [f_to_vec] states plus
    [update] simplifications, with arithmetic side conditions discharged by [lia]. *)
Hint Rewrite f_to_vec_cnot f_to_vec_σx f_to_vec_phase_shift using lia : f_to_vec_db.
Hint Rewrite (@update_index_eq bool) (@update_index_neq bool) (@update_twice_eq bool) (@update_same bool) using lia : f_to_vec_db.

(** * Indexed Vector Sum *)

(** A well-formed vector is the sum of its entries times the basis vectors. *)
Lemma basis_vector_decomp : forall {d} (ψ : Vector d),
  WF_Matrix ψ ->
  ψ = big_sum (fun i => (ψ i O) .* basis_vector d i) d.

Local Opaque Nat.mul.
(** Splits a sum of [2 * n] terms into its even- and odd-indexed halves. *)
Lemma vsum_sum : forall d n (f : nat -> Vector d),
  big_sum f (2 * n) =
  big_sum (fun i => f (2 * i)%nat) n .+ big_sum (fun i => f (2 * i + 1)%nat) n.
Local Transparent Nat.mul.

(** Isolates term [i] of a sum of [n] terms. *)
Lemma vsum_split : forall {d} (n i : nat) (v : nat -> Vector d),
  (i < n)%nat ->
  big_sum v n = (big_sum v i) .+ v i .+ (big_sum (shift v (i + 1)) (n - 1 - i)).

(** Swapping two in-range indices of the index map does not change the sum. *)
Lemma vsum_eq_up_to_fswap : forall {d} n f (v : nat -> Vector d) x y,
  (x < n)%nat -> (y < n)%nat ->
  big_sum (fun i => v (f i)) n = big_sum (fun i => v (fswap f x y i)) n.

(** * Indexed Kronecker Product *)

(** [vkron n f] is the tensor product [f 0 ⊗ f 1 ⊗ ... ⊗ f (n - 1)] of
    single-qubit vectors, with [f 0] leftmost. The empty product is [I 1]. *)
Fixpoint vkron n (f : nat -> Vector 2) : Vector (2 ^ n) :=
  match n with
  | 0 => I 1
  | S n' => vkron n' f ⊗ f n'
  end.

Lemma WF_vkron : forall n (f : nat -> Vector 2),
  (forall i, (i < n)%nat -> WF_Matrix (f i)) ->
  WF_Matrix (vkron n f).
#[export] Hint Resolve WF_vkron: wf_db.

Lemma WF_shift : forall m n j k (f : nat -> Matrix m n),
  (forall i, WF_Matrix (f i)) ->
  WF_Matrix (shift f j k).
#[export] Hint Resolve WF_shift: wf_db.

(** Extending the product by one factor on the right or on the left. *)
Lemma vkron_extend_r : forall n f,
  vkron n f ⊗ f n = vkron (S n) f.

Lemma vkron_extend_l : forall n (f : nat -> Vector 2),
  (forall i, WF_Matrix (f i)) ->
  (f O) ⊗ vkron n (shift f 1) = vkron (S n) f.

(** [n ⨂ A] applied to a classical state acts qubit-wise. *)
Lemma kron_n_f_to_vec : forall n (A : Square 2) f,
  n ⨂ A × f_to_vec n f = vkron n (fun k => A × ∣ f k ⟩ ).

Lemma Mscale_vkron_distr_r : forall n x (f : nat -> Vector 2),
  vkron n (fun i => x .* f i) = x ^ n .* vkron n f.

(** Isolates factor [i] of the product. *)
Lemma vkron_split : forall n i (f : nat -> Vector 2),
  (forall j, WF_Matrix (f j)) ->
  i < n ->
  vkron n f = (vkron i f) ⊗ f i ⊗ (vkron (n - 1 - i) (shift f (i + 1))).

(** [vkron n f] only depends on the first [n] factors. *)
Lemma vkron_eq : forall n (f f' : nat -> Vector 2),
  (forall i, i < n -> f i = f' i) -> vkron n f = vkron n f'.


(** Prepending a single-qubit basis ket to an [n]-dimensional basis vector. *)
Lemma basis_vector_prepend_0 : forall n k,
  n <> 0 -> k < n ->
  ∣0⟩ ⊗ basis_vector n k = basis_vector (2 * n) k.

Lemma basis_vector_prepend_1 : forall n k,
  n <> 0 -> k < n ->
  ∣1⟩ ⊗ basis_vector n k = basis_vector (2 * n) (k + n).

Local Opaque Nat.mul Nat.div Nat.modulo.
(** Appending a single-qubit basis ket to an [n]-dimensional basis vector. *)
Lemma basis_vector_append_0 : forall n k,
  n <> 0 -> k < n ->
  basis_vector n k ⊗ ∣0⟩ = basis_vector (2 * n) (2 * k).

Lemma basis_vector_append_1 : forall n k,
  n <> 0 -> k < n ->
  basis_vector n k ⊗ ∣1⟩ = basis_vector (2 * n) (2 * k + 1).
Local Transparent Nat.mul Nat.div Nat.modulo.

(** The all-zeros state is the basis vector with index 0. *)
Lemma kron_n_0_is_0_vector : forall (n:nat), n ⨂ ∣0⟩ = basis_vector (2 ^ n) O.

(** Expands the product of the states [∣0⟩ + Cexp (c * 2^(n-k-1)) ∣1⟩] into a
    sum over all [2 ^ n] basis vectors, where index [k] carries phase
    [Cexp (c * INR k)]. *)
Lemma vkron_to_vsum1 : forall n (c : R),
  n > 0 ->
  vkron n (fun k => ∣0⟩ .+ Cexp (c * 2 ^ (n - k - 1)) .* ∣1⟩) =
    big_sum (fun k => Cexp (c * INR k) .* basis_vector (2 ^ n) k) (2 ^ n).

(** [product x y n] is the parity (xor) of [x i && y i] over all [i < n]. *)
Fixpoint product (x y : nat -> bool) n :=
  match n with
  | O => false
  | S n' => xorb ((x n') && (y n')) (product x y n')
  end.

Lemma product_comm : forall f1 f2 n, product f1 f2 n = product f2 f1 n.

(** Updates at or beyond [dim] do not affect [product ... dim]. *)
Lemma product_update_oob : forall f1 f2 n b dim, (dim <= n)%nat ->
  product f1 (update f2 n b) dim = product f1 f2 dim.

Lemma product_0 : forall f n, product (fun _ : nat => false) f n = false.

(** Bit patterns of 0 and 1: 1 sets only the last index [n - 1]. *)
Lemma nat_to_funbool_0 : forall n, nat_to_funbool n 0 = (fun _ => false).

Lemma nat_to_funbool_1 : forall n, nat_to_funbool n 1 = (fun x => x =? n - 1).

Local Open Scope R_scope.
Local Open Scope C_scope.
Local Opaque Nat.mul.
(** Expands the product of the states [∣0⟩ + (-1)^(f k) ∣1⟩] into a signed sum
    over all basis vectors. The sign of index [k] is [(-1)] raised to the
    parity [product f (nat_to_funbool n k) n]. *)
Lemma vkron_to_vsum2 : forall n (f : nat -> bool),
  (n > 0)%nat ->
  vkron n (fun k => ∣0⟩ .+ (-1) ^ f k .* ∣1⟩) =
    big_sum
      (fun k => (-1) ^ (product f (nat_to_funbool n k) n) .* basis_vector (2 ^ n) k) (2^n).
Local Transparent Nat.mul.

(** Hadamard on a basis ket gives [(∣ 0 ⟩ + (-1)^b ∣ 1 ⟩) / √2]. *)
Lemma H_spec :
  forall b : bool, hadamard × ∣ b ⟩ = / √2 .* (∣ 0 ⟩ .+ (-1)^b .* ∣ 1 ⟩).

(** [n ⨂ hadamard] on a classical state gives the normalized signed sum over
    all basis vectors. *)
Lemma H_kron_n_spec : forall n x, (n > 0)%nat ->
  n ⨂ hadamard × f_to_vec n x =
    /√(2 ^ n) .* big_sum (fun k => (-1) ^ (product x (nat_to_funbool n k) n) .* basis_vector (2 ^ n) k) (2 ^ n).

(** On the all-zeros state, all signs are positive: the uniform superposition. *)
Lemma H0_kron_n_spec_alt : forall n, (n > 0)%nat ->
  n ⨂ hadamard × n ⨂ ∣0⟩ =
    /√(2 ^ n) .* big_sum (fun k => basis_vector (2 ^ n) k) (2 ^ n).