QuantumLib.DiscreteProb

Require Import VectorStates Measurement.

This file describes some theory of discrete probability distributions. Its main feature is 'apply_u', a function to describe the output distribution of running a quantum circuit.

(* Squared modulus of a complex number: the sum of the squares of the two
   components returned by fst and snd. *)
Definition Cmod2 (c : C) : R := fst c ^ 2 + snd c ^ 2.

(* Cmod2 is never negative. *)
Lemma Cmod2_ge_0 : forall c, 0 <= Cmod2 c.

(* Cmod2 agrees with the square of Cmod. *)
Lemma Cmod2_Cmod_sqr : forall c, (Cmod2 c = (Cmod c)^2)%R.

Definition of probability distribution


We represent a (discrete) probability distribution over the outcomes 0, ..., n-1 using a length-n list of real numbers. We support sampling from this distribution using the 'sample' function.

(* Sum of the entries of l: big_sum of the function picking the i-th entry
   (nth with default 0), bounded by length l. *)
Definition sum_over_list (l : list R) := big_sum (fun i => nth i l 0) (length l).

(* l is a distribution when every entry is non-negative and the entries
   sum to 1. *)
Definition distribution (l : list R) :=
  Forall (fun x => 0 <= x) l /\ sum_over_list l = 1.

(* The empty list sums to 0. *)
Lemma sum_over_list_nil : sum_over_list [] = 0.

(* Consing x onto l adds x to the sum. *)
Lemma sum_over_list_cons : forall x l,
  sum_over_list (x :: l) = (x + sum_over_list l)%R.

(* The sum of an append is the sum of the two parts. *)
Lemma sum_over_list_append : forall l1 l2,
  sum_over_list (l1 ++ l2) = (sum_over_list l1 + sum_over_list l2)%R.

(* A list of non-negative reals has a non-negative sum. *)
Lemma sum_over_list_geq_0 : forall l,
  Forall (fun x => 0 <= x) l ->
  0 <= sum_over_list l.

Sample from a distribution


Choose an element from the distribution based on random number r ∈ [0,1).
Example: Say that the input list is l = [.2, .3, .4, .1] (which might correspond to the probabilities of measuring the outcomes 00, 01, 10, 11). Then this function will return:
  • 0 for r ∈ [0, .2)
  • 1 for r ∈ [.2, .5)
  • 2 for r ∈ [.5, .9)
  • 3 for r ∈ [.9, 1)
The probability of getting a particular outcome is the size of the interval of r values that produce that outcome. (See interval_sum and pr_P below.)
(* sample l r walks down l, subtracting each entry from r, and returns the
   index of the first entry x with r < x (tested with Rlt_le_dec).
   If no entry qualifies, the recursion reaches nil and the result is
   length l (see sample_max). *)
Fixpoint sample (l : list R) (r : R) : nat :=
  match l with
  | nil => 0
  | x :: l' => if Rlt_le_dec r x then 0 else S (sample l' (r-x))
  end.

(* sample never returns more than the length of the list. *)
Lemma sample_ub : forall l r, (sample l r <= length l)%nat.

(* For r below the total weight, sample returns a valid index of l. *)
Lemma sample_ub_lt : forall l r,
  0 <= r < sum_over_list l ->
  (sample l r < length l)%nat.

(* Trivial lower bound on the result of sample. *)
Lemma sample_lb : forall l r, (0 <= sample l r)%nat.

(* If r is at least the total weight of l, sample runs off the end of the
   list and returns length l. *)
Lemma sample_max : forall l r,
    Forall (fun x => 0 <= x) l ->
    sum_over_list l <= r ->
    sample l r = length l.

(* When r falls inside the weight of l1, appending l2 does not change the
   sampled index. *)
Lemma sample_append_l : forall l1 l2 r,
    0 <= r ->
    r < sum_over_list l1 ->
    sample (l1 ++ l2) r = sample l1 r.

(* When r is past the weight of l1, sampling l1 ++ l2 skips l1 entirely and
   continues in l2 with the weight of l1 subtracted from r. *)
Lemma sample_append_r : forall l1 l2 r,
    Forall (fun x => 0 <= x) l1 ->
    Forall (fun x => 0 <= x) l2 ->
    sum_over_list l1 <= r ->
    (sample (l1 ++ l2) r = (length l1) + sample l2 (r - sum_over_list l1))%nat.

(* A prefix of m zero-weight entries is always skipped by sample. *)
Lemma sample_repeat_lb : forall m l r,
    0 <= r ->
    (m <= sample (repeat 0%R m ++ l) r)%nat.

Probability that a distribution satisfies a predicate (pr_outcome_sum)


Intuitively, the probability that an element satisfies boolean predicate f is the sum of the probabilities of all elements for which f holds.
(* Total weight in l of the indices accepted by f: index i contributes
   nth i l 0 when f i is true and 0 otherwise. *)
Definition pr_outcome_sum (l : list R) (f : nat -> bool) : R :=
  big_sum (fun i => if f i then nth i l 0 else 0) (length l).

(* Unfolding pr_outcome_sum on a cons: the head counts only when f O holds,
   and the tail is measured against the shifted predicate. *)
Lemma pr_outcome_sum_extend : forall x l f,
  pr_outcome_sum (x :: l) f
  = if f O
    then (x + pr_outcome_sum l (fun y => f (S y)))%R
    else pr_outcome_sum l (fun y => f (S y)).

(* pr_outcome_sum splits over append; the predicate for l2 is shifted by
   length l1. *)
Lemma pr_outcome_sum_append : forall l1 l2 f,
  pr_outcome_sum (l1 ++ l2) f
  = (pr_outcome_sum l1 f + pr_outcome_sum l2 (fun x => f (length l1 + x)%nat))%R.

(* An all-zero list has weight 0 for every predicate. *)
Lemma pr_outcome_sum_repeat_false : forall n f,
  pr_outcome_sum (repeat 0 n) f = 0.

(* Variant of pr_outcome_sum_extend with the conditional pushed inside the
   sum. *)
Definition pr_outcome_sum_extend' :
  forall l f a,
    (pr_outcome_sum (a :: l) f = (if (f O) then a else 0) + pr_outcome_sum l (fun i => f (S i)))%R.

(* Predicates that agree on all indices of l give the same weight. *)
Lemma pr_outcome_sum_replace_f : forall l f1 f2,
  (forall x, (x < length l)%nat -> f1 x = f2 x) ->
  pr_outcome_sum l f1 = pr_outcome_sum l f2.

(* A predicate that is false on every index of l has weight 0. *)
Lemma pr_outcome_sum_false : forall l f,
  (forall i, (i < length l)%nat -> f i = false) ->
  pr_outcome_sum l f = 0.

(* A predicate that is true on every index of l has the full weight of l. *)
Lemma pr_outcome_sum_true : forall l f,
  (forall i, (i < length l)%nat -> f i = true) ->
  pr_outcome_sum l f = sum_over_list l.

(* The weight of f is the total weight minus the weight of its negation. *)
Lemma pr_outcome_sum_negb : forall l f,
  pr_outcome_sum l f = (sum_over_list l - pr_outcome_sum l (fun x => negb (f x)))%R.

(* Weakening a predicate with a disjunction cannot decrease its weight. *)
Lemma pr_outcome_sum_orb : forall l f1 f2,
  Forall (fun x => 0 <= x) l ->
  pr_outcome_sum l f1 <= pr_outcome_sum l (fun rnd => f1 rnd || f2 rnd).

(* Monotonicity: if f1 implies f2 pointwise, f1 weighs no more than f2. *)
Lemma pr_outcome_sum_implies : forall l f1 f2,
  Forall (fun x => 0 <= x) l ->
  (forall x, f1 x = true -> f2 x = true) ->
  (pr_outcome_sum l f1 <= pr_outcome_sum l f2)%R.

(* On a list of non-negative reals, every predicate has non-negative weight. *)
Lemma pr_outcome_sum_ge_0 :
  forall l f, Forall (fun x => 0 <= x) l -> 0 <= pr_outcome_sum l f.

Probability that a distribution satisfies a predicate (pr_P)


Mathematically, the probability that an element satisfies a (not necessarily boolean) predicate is the size of the range of r-values for which the element returned from 'sample' satisfies the predicate.
(* interval_sum P rl rr r: the r-values between rl and rr that satisfy P
   have total length r.
   - SingleInterval: P holds strictly between r1 and r2 and fails strictly
     between rl and r1 and strictly between r2 and rr; the length is r2 - r1.
   - CombineIntervals: split the range at rm and add the two lengths.
   Only open sub-intervals are constrained, so the endpoints themselves are
   unconstrained. *)
Inductive interval_sum (P : R -> Prop) (rl rr : R) : R -> Prop :=
| SingleInterval : forall r1 r2, rl <= r1 <= r2 /\ r2 <= rr ->
    (forall r, r1 < r < r2 -> P r) ->
    (forall r, rl < r < r1 -> ~ P r) ->
    (forall r, r2 < r < rr -> ~ P r) ->
    interval_sum P rl rr (r2 - r1)%R

| CombineIntervals : forall rm r1 r2, rl <= rm <= rr ->
    interval_sum P rl rm r1 ->
    interval_sum P rm rr r2 ->
    interval_sum P rl rr (r1 + r2).

(* Translating both the predicate and the range by a preserves the measure. *)
Lemma interval_sum_shift :
  forall P rl rr r a,
    interval_sum P rl rr r ->
    interval_sum (fun x => P (x - a)%R) (rl + a)%R (rr + a)%R r.

(* Predicates that are equivalent on the range have the same measure. *)
Lemma interval_sum_same :
  forall P1 P2 rl rr r,
    interval_sum P1 rl rr r ->
    (forall x, rl <= x < rr -> (P1 x <-> P2 x)) ->
    interval_sum P2 rl rr r.

(* Converse form of interval_sum_shift: undo a translation by a. *)
Lemma interval_sum_shift_alt :
  forall P rl rr r a,
    interval_sum (fun x => P (x + a)%R) (rl - a)%R (rr - a)%R r ->
    interval_sum P rl rr r.

(* A measure produced by interval_sum is never negative. *)
Lemma interval_sum_gt_0 : forall P rl rr r, interval_sum P rl rr r -> r >= 0.

(* A measured range can be cut at any intermediate point rm; the measures of
   the two halves add up to the original. *)
Lemma interval_sum_break :
  forall P rl rm rr r,
    interval_sum P rl rr r ->
    rl <= rm <= rr ->
    exists r1 r2 : R, interval_sum P rl rm r1 /\ interval_sum P rm rr r2 /\ (r = r1 + r2)%R.

(* The measure assigned by interval_sum is unique. *)
Lemma interval_sum_unique : forall P rl rr r1 r2,
    interval_sum P rl rr r1 ->
    interval_sum P rl rr r2 ->
    r1 = r2.

Mathematical measure of P on the interval (0,1)
(* pr_P P r holds when interval_sum assigns measure r to P on the range
   from 0 to 1. *)
Definition pr_P P r := interval_sum P 0%R 1%R r.

(* Predicates equivalent on 0 <= rnd < 1 have the same probability. *)
Lemma pr_P_same :
  forall P1 P2 r,
    (forall rnd, 0 <= rnd < 1 -> P1 rnd <-> P2 rnd) ->
    pr_P P1 r ->
    pr_P P2 r.

(* Generalisation of pr_outcome_sum_eq_aux to lists whose total weight is an
   arbitrary r rather than 1. *)
Lemma pr_outcome_sum_eq_aux' : forall (l : list R) (f : nat -> bool) r,
    Forall (fun x => 0 <= x) l ->
    sum_over_list l = r ->
    interval_sum (fun rnd => f (sample l rnd) = true) 0 r (pr_outcome_sum l f).

The pr_outcome_sum and pr_P definitions of probability are consistent.
(* For a distribution l, pr_outcome_sum l f is the pr_P-probability that the
   sampled outcome satisfies f. *)
Lemma pr_outcome_sum_eq_aux : forall (l : list R) (f : nat -> bool),
    distribution l ->
    pr_P (fun rnd => f (sample l rnd) = true) (pr_outcome_sum l f).

(* An upper bound r on pr_outcome_sum yields a pr_P-probability r0 that lies
   between 0 and r. *)
Lemma pr_outcome_sum_leq_exists : forall l f r,
  distribution l ->
  pr_outcome_sum l f <= r ->
  exists r0, (0 <= r0 <= r)%R /\ pr_P (fun rnd => f (sample l rnd) = true) r0.

(* The probability assigned by pr_P is unique. *)
Lemma pr_P_unique : forall P r1 r2,
    pr_P P r1 ->
    pr_P P r2 ->
    r1 = r2.

(* For a distribution, the pr_outcome_sum and pr_P notions of probability
   coincide. *)
Lemma pr_outcome_sum_eq : forall f l r,
  distribution l ->
  pr_outcome_sum l f = r <-> pr_P (fun rnd => f (sample l rnd) = true) r.

Distribution created by running a quantum program


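Given our definition of sample, we can define a function to apply a quantum program and return the distribution obtained by measuring all qubits.
apply_u itself takes no random input; to draw an outcome, pass its result to 'sample' together with a random input rnd in [0,1).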
(* Output distribution of u: multiply u with basis_vector (2^dim) 0, turn the
   resulting vector into a list with vec_to_list, and map Cmod2 over it.
   NOTE(review): presumably entry i is the probability of measuring outcome i
   (cf. nth_apply_u_probability_of_outcome) -- confirm. *)
Definition apply_u {dim} (u : Square (2 ^ dim)) : list R :=
  let v := u × basis_vector (2^dim) 0 in
  map Cmod2 (vec_to_list v).

(* Mapping Cmod2 over any list yields only non-negative entries. *)
Lemma pos_Cmod2_list :
  forall l, Forall (fun x => 0 <= x) (map Cmod2 l).

(* big_sum is made opaque around this lemma and restored afterwards. *)
Local Opaque big_sum.
(* Auxiliary form of sum_over_list_Cmod2_vec_to_list, stated for the helper
   vec_to_list' with an index offset. *)
Lemma sum_over_list_Cmod2_vec_to_list' : forall d x l,
  sum_over_list (map Cmod2 (@vec_to_list' x d l)) =
    big_sum (fun i : nat => (Cmod (l (i + x - d)%nat O) ^ 2)%R) d.
Local Transparent big_sum.

(* Summing Cmod2 over the list form of a vector equals the big_sum of the
   squared moduli of its entries. *)
Lemma sum_over_list_Cmod2_vec_to_list :
  forall d (l : Vector d),
    sum_over_list (map Cmod2 (vec_to_list l)) =
      big_sum (fun i : nat => (Cmod (l i O) ^ 2)%R) d.

(* Applying a well-formed unitary yields a distribution. *)
Lemma distribution_apply_u : forall {dim} u,
    WF_Unitary u ->
    distribution (@apply_u dim u).

(* apply_u on an n-qubit matrix returns a list of 2^n entries. *)
Lemma length_apply_u : forall n (u : Square (2 ^ n)),
  length (apply_u u) = (2 ^ n)%nat.

(* Entry x of apply_u u is probability_of_outcome for basis state x after
   applying u to basis state 0. *)
Lemma nth_apply_u_probability_of_outcome : forall n (u : Square (2 ^ n)) x,
  (x < 2 ^ n)%nat ->
  WF_Matrix u ->
  nth x (apply_u u) 0
    = probability_of_outcome
        (basis_vector (2^n) x)
        (u × basis_vector (2^n) 0).

Uniform distribution


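Uniform sampling in the range [lower, upper): outcomes below lower get probability 0, and each of the (upper - lower) remaining outcomes gets probability 1 / (upper - lower).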
(* lower zeros followed by (upper - lower) copies of 1 / INR (upper - lower).
   The subtraction is on nat, so it is 0 when upper <= lower. *)
Definition uniform (lower upper : nat) : list R :=
  repeat 0 lower ++ repeat (1/ INR (upper - lower))%R (upper - lower).

(* Repeating a non-negative real gives a list of non-negative reals. *)
Lemma repeat_gt0 : forall m r, 0 <= r -> Forall (fun x => 0 <= x) (repeat r m).

(* m copies of x sum to m times x. *)
Lemma sum_over_list_repeat : forall m x, (sum_over_list (repeat x m) = INR m * x)%R.

(* Sampling from uniform l u with 0 <= r < 1 lands in the range l .. u-1. *)
Lemma sample_uniform : forall l u r,
  (l < u)%nat -> 0 <= r < 1 -> (l <= sample (uniform l u) r < u)%nat.

(* uniform l r is a distribution whenever the range is non-empty. *)
Lemma distribution_uniform : forall l r,
  (l < r)%nat ->
  distribution (uniform l r).

(* uniform l r has exactly r entries when l <= r. *)
Lemma length_uniform : forall l r, (l <= r)%nat -> (length (uniform l r) = r)%nat.

Joint distribution


(* Multiply every entry of l by the real factor r. *)
Fixpoint scale r l :=
  match l with
  | nil => nil
  | h :: t => (r * h)%R :: scale r t
  end.

Combine distributions l1 and l2, where l2 may depend on the value of l1
(* join' l1 l2 n concatenates, for k = 0 .. n-1 in increasing order, the
   list (l2 k) scaled by entry k of l1 (nth with default 0). *)
Fixpoint join' l1 l2 n :=
  match n with
  | O => nil
  | S n' => join' l1 l2 n' ++ scale (nth n' l1 0) (l2 n')
  end.
(* Joint list of l1 and the l1-indexed family l2: join' run over every index
   of l1. *)
Definition join l1 l2 := join' l1 l2 (length l1).

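Given a nat x consisting of (n+m) bits, fst m x extracts the first n bits (x / 2^m) and snd m x extracts the last m bits (x mod 2^m). Example application: when sampling from (join l1 l2) where every (l2 k) has length 2^m, you can use fst m and snd m to split the result into the outcome from l1 and the outcome from l2.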
 (* fst m x drops the last m bits of x (integer division by 2^m).
    NOTE(review): this reuses the name of the pair projection that Cmod2
    applies to a complex number earlier in the file -- confirm the shadowing
    is intended. *)
 Definition fst (m x : nat) := (x / 2 ^ m)%nat.
 (* snd m x keeps only the last m bits of x (remainder modulo 2^m).
    NOTE(review): this reuses the name of the pair projection that Cmod2
    applies to a complex number earlier in the file -- confirm the shadowing
    is intended. *)
 Definition snd (m x : nat) := (x mod 2 ^ m)%nat.

(* fst of zero is zero. *)
Lemma fst_0 : forall m, fst m 0 = O.

(* Adding 2^m increments the high part by one. *)
Lemma fst_plus : forall m x, fst m (2 ^ m + x) = S (fst m x).

(* Numbers below 2^m have high part zero. *)
Lemma fst_small : forall m x, (x < 2 ^ m)%nat -> fst m x = O.

(* snd of zero is zero. *)
Lemma snd_0 : forall m, snd m 0 = O.

(* Numbers below 2^m are their own low part. *)
Lemma snd_small : forall m x, (x < 2 ^ m)%nat -> snd m x = x.

(* Adding 2^m does not change the low part. *)
Lemma snd_plus : forall m x, snd m (2 ^ m + x) = snd m x.

(* fst recovers x from x * 2^n + y when y fits in n bits. *)
Lemma simplify_fst : forall n x y,
  (y < 2 ^ n)%nat ->
  fst n (x * 2 ^ n + y) = x.

(* snd recovers y from x * 2^n + y when y fits in n bits. *)
Lemma simplify_snd : forall n x y,
  (y < 2 ^ n)%nat ->
  snd n (x * 2 ^ n + y) = y.

(* Scaling a list scales its sum. *)
Lemma sum_over_list_scale : forall x l,
  sum_over_list (scale x l) = (x * sum_over_list l)%R.

(* Extending a prefix by one element adds that element to the sum. *)
Lemma sum_over_list_firstn : forall n l, (n < length l)%nat ->
  sum_over_list (firstn (S n) l) = (sum_over_list (firstn n l) + nth n l 0)%R.

(* Scaling by a non-negative factor preserves non-negativity of all entries. *)
Lemma Forall_scale_geq : forall a l,
  (0 <= a)%R ->
  Forall (fun x : R => 0 <= x) l ->
  Forall (fun x : R => 0 <= x) (scale a l).

(* The join of non-negative lists is non-negative. *)
Lemma join_geq_0 : forall l1 l2,
  Forall (fun x : R => 0 <= x) l1 ->
  (forall i, (i < length l1)%nat -> Forall (fun x : R => 0 <= x) (l2 i)) ->
  Forall (fun x : R => 0 <= x) (join l1 l2).

(* The join of a distribution with a family of distributions is again a
   distribution. *)
Lemma distribution_join : forall l1 l2,
  distribution l1 ->
  (forall i, (i < length l1)%nat -> distribution (l2 i)) ->
  distribution (join l1 l2).

(* Unfolding join at the head of l1: the first block is (l2 O) scaled by x,
   followed by the join of the tail with the shifted family. *)
Lemma join_cons : forall x l1 l2,
  join (x :: l1) l2 = scale x (l2 O) ++ join l1 (shift l2 1).

(* Scaling does not change the length of a list. *)
Lemma length_scale : forall a l, length (scale a l) = length l.

Sampling from (join l1 l2), where every (l2 k) has length 2^m, and taking the leading bits of the result (fst m) is the same as sampling directly from l1.
(* The high part (fst m) of a sample from join l1 l2 is a sample from l1
   with the same random number. *)
Lemma fst_sample_join : forall l1 l2 rnd m,
  0 <= rnd ->
  Forall (fun x : R => 0 <= x) l1 ->
  (forall k, length (l2 k) = 2 ^ m)%nat ->
  (forall k, (k < length l1)%nat -> distribution (l2 k)) ->
  fst m (sample (join l1 l2) rnd) = sample l1 rnd.

(* Sampling from a list scaled by a positive factor equals sampling from the
   original list with the random number divided by that factor. *)
Lemma sample_scale : forall a l rnd,
  a > 0 -> sample (scale a l) rnd = sample l (rnd / a).

(* Rescale rnd for a second sampling step: subtract the total weight of the
   first o entries of l, then divide by entry o of l (nth with default 0). *)
Definition compute_new_rnd rnd l o : R :=
  (rnd - sum_over_list (firstn o l)) / nth o l 0.

Sampling from (join l1 l2), where every (l2 k) has length 2^m, and taking the last m bits of the result (snd m) is the same as sampling from l1 and, based on the outcome o, sampling from (l2 o) with the rescaled random number given by compute_new_rnd.
(* The low part (snd m) of a sample from join l1 l2 is a sample from (l2 o),
   where o is the outcome sampled from l1 and the random number is rescaled
   by compute_new_rnd. *)
Lemma snd_sample_join : forall l1 l2 rnd m,
    0 <= rnd < sum_over_list l1 ->
    Forall (fun x : R => 0 <= x) l1 ->
    (forall k, length (l2 k) = (2 ^ m)%nat) ->
    (forall k, (k < length l1)%nat -> distribution (l2 k)) ->
    let o := sample l1 rnd in
    let rnd' := compute_new_rnd rnd l1 o in
    snd m (sample (join l1 l2) rnd) = sample (l2 o) rnd'.

(* Scaling a list scales the weight of every predicate. *)
Lemma pr_outcome_sum_scale : forall a l f,
  pr_outcome_sum (scale a l) f = (a * pr_outcome_sum l f)%R.

(* If the first x lists of the family all have length m, join' over x
   indices has length x * m. *)
Lemma length_join' : forall x l1 l2 m,
  (forall k, (k < x)%nat -> length (l2 k) = m) ->
  (length (join' l1 l2 x) = x * m)%nat.

(* Extending a prefix by one element adds that element to the weight exactly
   when f accepts its index. *)
Lemma pr_outcome_sum_firstn : forall n l f,
  (n < length l)%nat ->
  pr_outcome_sum (firstn (S n) l) f =
    ((if f n then nth n l 0 else 0) + pr_outcome_sum (firstn n l) f)%R.
  Local Opaque firstn.

If the probability of f1 in distr1(=l1) is at least r1 and the probability of f2 in each distr2(=l2 i) is at least r2, then the probability of f1&f2 in (join l1 l2) is at least r1 * r2.
Local Transparent firstn.
(* Lower bound for the joint predicate: f1 is applied to the high part
   (fst n) of the joint outcome and f2 to the high and low parts together.
   NOTE(review): the operators in this statement were reconstructed from a
   rendering that dropped them; the >= direction follows the lemma name and
   the conjunction inside the family hypothesis is an assumption -- confirm
   against the original source. *)
Lemma pr_outcome_sum_join_geq : forall l1 l2 f1 f2 r1 r2 n,
  distribution l1 ->
  (0 <= r2)%R ->
  pr_outcome_sum l1 f1 >= r1 ->
  (forall i, (i < length l1)%nat ->
        length (l2 i) = (2 ^ n)%nat /\
        pr_outcome_sum (l2 i) (f2 i) >= r2) ->
  let f1f2 z := (let x := fst n z in
                 let y := snd n z in
                 f1 x && f2 x y) in
  pr_outcome_sum (join l1 l2) f1f2 >= (r1 * r2)%R.

(* Rewrites the weight of a predicate on the high part (fst k) of the outcome
   of apply_u as a big_sum of prob_partial_meas terms over the first n
   qubits. The product with the 0/1 indicator is real multiplication; the
   product u × basis_vector is the matrix product. *)
Lemma rewrite_pr_outcome_sum : forall n k (u : Square (2 ^ (n + k))) f,
  WF_Matrix u ->
  pr_outcome_sum (apply_u u) (fun x => f (fst k x))
  = big_sum (fun x => ((if f x then 1 else 0) *
                         prob_partial_meas (basis_vector (2 ^ n) x)
                           (u × basis_vector (2 ^ (n + k)) 0))%R) (2 ^ n).

Repeat independent runs


rnds: source of randomness for sampling; its length is the maximum number of iterations (niter). body: operation to iterate.
(* Run body on the random numbers in rnds, in order, and return the first
   Some result. Returns None when rnds is exhausted (in particular when it
   is empty), so length rnds bounds the number of attempts. *)
Fixpoint iterate {A} (rnds : list R) (body : R -> option A) :=
  match rnds with
  | nil => None
  | rnd :: rnds' =>
      match body rnd with
      | Some v => Some v
      | None => iterate rnds' body
      end
  end.

(* Bodies that agree on 0 <= r < 1 give the same result on random numbers
   drawn from that range. *)
Lemma iterate_replace_body : forall {A} rnds (body body' : R -> option A),
  Forall (fun r : R => 0 <= r < 1) rnds ->
  (forall r, 0 <= r < 1 -> body r = body' r) ->
  iterate rnds body = iterate rnds body'.

(* pr_Ps Ps i r: probability r that a list of i random numbers satisfies Ps.
   - pr_Ps_base: for zero random numbers the probability is 1, provided Ps
     holds of the empty list.
   - pr_Ps_rec: if Ps on (rnd :: rnds) is equivalent to Ps on rnds together
     with P on rnd (for values in 0 <= r < 1), the probabilities multiply. *)
Inductive pr_Ps : ((list R) -> Prop) -> nat -> R -> Prop :=
| pr_Ps_base : forall (Ps : (list R) -> Prop), Ps nil -> pr_Ps Ps O 1
| pr_Ps_rec : forall Ps i r1 P r2,
    pr_Ps Ps i r1 ->
    pr_P P r2 ->
    (forall rnd rnds, 0 <= rnd < 1 ->
                 Forall (fun r : R => 0 <= r < 1) rnds ->
                 Ps (rnd :: rnds) <-> Ps rnds /\ P rnd) ->
    pr_Ps Ps (S i) (r1 * r2).

(* Predicates that are equivalent on lists drawn from 0 <= r < 1 have the
   same pr_Ps probability. *)
Lemma pr_Ps_same :
  forall i Ps1 Ps2 r,
    (forall rnds, Forall (fun r : R => 0 <= r < 1) rnds -> Ps1 rnds <-> Ps2 rnds) ->
    pr_Ps Ps1 i r ->
    pr_Ps Ps2 i r.

(* Any predicate with a pr_Ps probability holds of the empty list. *)
Lemma pr_Ps_nil :
  forall i Ps r,
    pr_Ps Ps i r ->
    Ps nil.

(* The probability assigned by pr_Ps is unique. *)
Lemma pr_Ps_unique : forall Ps i r1 r2,
  pr_Ps Ps i r1 ->
  pr_Ps Ps i r2 ->
  r1 = r2.

(* Boolean test: true exactly when the option is None. *)
Definition isNone {A} (o : option A) := match o with None => true | _ => false end.

(* If a single run of body fails (returns None) with probability r, then n
   independent runs all fail with probability r ^ n. *)
Lemma pr_iterate_None :
  forall {A} n (body : R -> option A) r,
    pr_P (fun rnd => isNone (body rnd) = true) r ->
    pr_Ps (fun rnds => isNone (iterate rnds body) = true) n (r ^ n)%R.