(** * QuantumLib.Summation: algebraic structure classes and finite sums *)

Require Import List.
Require Export Prelim.

(* Scope holding the algebraic notations ([+], [0], [-], [×], [1], [/], [⋅])
   that this file attaches to the structure classes; it is delimited by the
   key [G], so e.g. [(a + b)%G] forces the group reading. *)
Declare Scope group_scope.
Delimit Scope group_scope with G.

(* Opened here so those notations are the default reading in what follows. *)
Open Scope group_scope.

(** A monoid: a carrier [G] with an associative binary operation [Gplus]
    for which [Gzero] is both a left and a right identity. *)
Class Monoid G :=
  { Gzero : G
  ; Gplus : G -> G -> G
  ; Gplus_0_l : forall g, Gplus Gzero g = g
  ; Gplus_0_r : forall g, Gplus g Gzero = g
  ; Gplus_assoc : forall g h i, Gplus g (Gplus h i) = Gplus (Gplus g h) i
  }.

(** Inside [group_scope], [+] is the monoid operation and [0] its unit. *)
Infix "+" := Gplus : group_scope.
Notation "0" := Gzero : group_scope.


(** A group: a monoid in which every element [g] has a two-sided inverse
    [Gopp g]. *)
Class Group G `{Monoid G} :=
  { Gopp : G -> G
  ; Gopp_l : forall g, (Gopp g) + g = 0
  ; Gopp_r : forall g, g + (Gopp g) = 0
  }.

(** A commutative (abelian) group: [Gplus] commutes. *)
Class Comm_Group G `{Group G} :=
  { Gplus_comm : forall a b, Gplus a b = Gplus b a }.

(** Subtraction: [g1 - g2] is [g1] plus the inverse of [g2]. *)
Definition Gminus {G} `{Group G} (g1 g2 : G) := g1 + (Gopp g2).

(** Unary [-] is the group inverse; binary [-] is [Gminus]. *)
Notation "- x" := (Gopp x) : group_scope.
Infix "-" := Gminus : group_scope.

(** A ring with unit: a commutative group under [+] together with an
    associative multiplication [Gmult] that has [Gone] as a two-sided
    identity and distributes over [+] on both sides. Multiplication itself
    need not commute; that is what [Comm_Ring] adds. *)
Class Ring R `{Comm_Group R} :=
  { Gone : R
  ; Gmult : R -> R -> R
  ; Gmult_1_l : forall a, Gmult Gone a = a
  ; Gmult_1_r : forall a, Gmult a Gone = a
  ; Gmult_assoc : forall a b c, Gmult a (Gmult b c) = Gmult (Gmult a b) c
  ; Gmult_plus_distr_l : forall a b c, Gmult c (a + b) = (Gmult c a) + (Gmult c b)
  ; Gmult_plus_distr_r : forall a b c, Gmult (a + b) c = (Gmult a c) + (Gmult b c)
  }.

(** A commutative ring: [Gmult] commutes. *)
Class Comm_Ring R `{Ring R} :=
  { Gmult_comm : forall a b, Gmult a b = Gmult b a }.

(** Inside [group_scope], [×] is ring multiplication and [1] its unit. *)
Infix "×" := Gmult : group_scope.
Notation "1" := Gone : group_scope.

(** A field: a nontrivial ([1 <> 0]) commutative ring in which every nonzero
    [f] has a right inverse [Ginv f]. Nothing is required of [Ginv 0]. *)
Class Field F `{Comm_Ring F} :=
  { Ginv : F -> F
  ; G1_neq_0 : 1 <> 0
  ; Ginv_r : forall f, f <> 0 -> f × (Ginv f) = 1 }.

(** Division: [g1 / g2] is [g1] times the inverse of [g2]. *)
Definition Gdiv {G} `{Field G} (g1 g2 : G) := Gmult g1 (Ginv g2).

(** Unary [/] is the field inverse; binary [/] is [Gdiv]. *)
Notation "/ x" := (Ginv x) : group_scope.
Infix "/" := Gdiv : group_scope.

(** A vector space over the field [F]: a commutative group [V] with a scalar
    action [Vscale] that fixes vectors under [1], distributes over vector
    addition, and is compatible with field multiplication. *)
Class Vector_Space V F `{Comm_Group V} `{Field F} :=
  { Vscale : F -> V -> V
  ; Vscale_1 : forall v, Vscale 1 v = v
  ; Vscale_dist : forall a u v, Vscale a (u + v) = Vscale a u + Vscale a v
  ; Vscale_assoc : forall a b v, Vscale a (Vscale b v) = Vscale (a × b) v
  }.
(* NOTE(review): distributivity over scalar addition,
   Vscale (a + b) v = Vscale a v + Vscale b v, is not among the axioms above;
   confirm whether that omission is intentional. *)

(** [c ⋅ v] scales the vector [v] by the scalar [c]. *)
Infix "⋅" := Vscale (at level 40) : group_scope.


(** The commutative-ring operations satisfy Coq's [ring_theory] record with
    Leibniz equality; presumably intended for registering the [ring] tactic
    via [Add Ring] — confirm at the use site. *)
Lemma G_ring_theory : forall {R} `{Comm_Ring R}, ring_theory 0 1 Gplus Gmult Gminus Gopp eq.

(** The field operations satisfy Coq's [field_theory] record. *)
Lemma G_field_theory : forall {F} `{Field F}, field_theory 0 1 Gplus Gmult Gminus Gopp Gdiv Ginv eq.


(** Left cancellation: [a + g = a + h] implies [g = h]. *)
Lemma Gplus_cancel_l : forall {G} `{Group G} (g h a : G),
  a + g = a + h -> g = h.

(** Right cancellation: [g + a = h + a] implies [g = h]. *)
Lemma Gplus_cancel_r : forall {G} `{Group G} (g h a : G),
  g + a = h + a -> g = h.

(** Any left inverse of [g] equals [Gopp g]. *)
Lemma Gopp_unique_l : forall {G} `{Group G} (g h : G),
  h + g = 0 -> h = Gopp g.

(** Any right inverse of [g] equals [- g]. *)
Lemma Gopp_unique_r : forall {G} `{Group G} (g h : G),
  g + h = 0 -> h = - g.

(** Inversion is an involution. *)
Lemma Gopp_involutive : forall {G} `{Group G} (g : G),
  - (- g) = g.

(** Inverting a sum reverses the order of the summands. *)
Lemma Gopp_plus_distr : forall {G} `{Group G} (g h : G),
  - (g + h) = - h + - g.

(** Scaling the zero vector yields the zero vector. *)
Lemma Vscale_zero : forall {V F} `{Vector_Space V F} (c : F),
  c ⋅ 0 = 0.

(** [0] annihilates on the left of a product. *)
Lemma Gmult_0_l : forall {R} `{Ring R} (r : R),
  0 × r = 0.

(** [0] annihilates on the right of a product. *)
Lemma Gmult_0_r : forall {R} `{Ring R} (r : R),
  r × 0 = 0.

(** For nonzero [f], [Ginv f] is also a left inverse. *)
Lemma Ginv_l : forall {F} `{Field F} (f : F), f <> 0 -> (Ginv f) × f = 1.

(** A nonzero left factor can be cancelled. *)
Lemma Gmult_cancel_l : forall {F} `{Field F} (g h a : F),
  a <> 0 -> a × g = a × h -> g = h.

(** A nonzero right factor can be cancelled. *)
Lemma Gmult_cancel_r : forall {F} `{Field F} (g h a : F),
  a <> 0 -> g × a = h × a -> g = h.

(** A field has no zero divisors. *)
Lemma Gmult_neq_0 : forall {F} `{Field F} (a b : F), a <> 0 -> b <> 0 -> a × b <> 0.

(** The inverse of a product of nonzero elements is the product of inverses. *)
Lemma Ginv_mult_distr : forall {F} `{Field F} (a b : F),
    a <> 0 -> b <> 0 ->
    / (a × b) = / a × / b.


(** [nat] forms a monoid under [plus] with unit [0]. The monoid-law
    obligations left by [Program] are closed by [program_simpl], with [lia]
    attempted on anything that remains. *)
Program Instance nat_is_monoid : Monoid nat :=
  { Gplus := plus
  ; Gzero := 0
  }.
Solve All Obligations with program_simpl; try lia.

(** * Summation functions *)

(** [times_n g n] adds [n] copies of [g], nested to the right:
    [g + (g + (... + 0))]. *)
Fixpoint times_n {G} `{Monoid G} (g : G) (n : nat) :=
  match n with
  | 0 => 0
  | S n' => g + times_n g n'
  end.

(** Sum of a list, folded from the right: [g1 + (g2 + (... + 0))]. *)
Fixpoint G_big_plus {G} `{Monoid G} (gs : list G) : G :=
  match gs with
  | nil => 0
  | g :: gs' => g + (G_big_plus gs')
  end.

(** Product of a list, folded from the right: [r1 × (r2 × (... × 1))]. *)
Fixpoint G_big_mult {R} `{Ring R} (rs : list R) : R :=
  match rs with
  | nil => 1
  | r :: rs' => r × (G_big_mult rs')
  end.

(** Sum up to [n], exclusive: [big_sum f n] is [f 0 + f 1 + ... + f (n - 1)]. *)
(** [big_sum f n] adds [f 0], ..., [f (n - 1)], nested to the left starting
    from [0]: [((0 + f 0) + f 1) + ... + f (n - 1)]. [big_sum f 0] is [0]. *)
Fixpoint big_sum {G : Type} `{Monoid G} (f : nat -> G) (n : nat) : G :=
  match n with
  | 0 => 0
  | S n' => (big_sum f n') + (f n')
  end.

(** The [big_sum] of an everywhere-zero function is [0]. *)
Lemma big_sum_0 : forall {G} `{Monoid G} f n,
    (forall x, f x = 0) -> big_sum f n = 0.

(** Equal summand functions give equal sums. *)
Lemma big_sum_eq : forall {G} `{Monoid G} f g n, f = g -> big_sum f n = big_sum g n.

(** Only indices below [n] matter: if [f] is zero there, the sum is [0]. *)
Lemma big_sum_0_bounded : forall {G} `{Monoid G} f n,
    (forall x, (x < n)%nat -> f x = 0) -> big_sum f n = 0.

(** Functions that agree below [n] have the same [big_sum] up to [n]. *)
Lemma big_sum_eq_bounded : forall {G} `{Monoid G} f g n,
    (forall x, (x < n)%nat -> f x = g x) -> big_sum f n = big_sum g n.

(** Peels the first term [f 0] off the front of a sum. *)
Lemma big_sum_shift : forall {G} `{Monoid G} n (f : nat -> G),
  big_sum f (S n) = f O + big_sum (fun x => f (S x)) n.

(** Summing a constant [n] times agrees with [times_n]. *)
Lemma big_sum_constant : forall {G} `{Monoid G} g n,
  big_sum (fun _ => g) n = times_n g n.

(** A list whose elements all equal [g] sums to [length l] copies of [g]. *)
Lemma big_plus_constant : forall {G} `{Monoid G} (l : list G) (g : G),
  (forall h, In h l -> h = g) -> G_big_plus l = (times_n g (length l))%nat.

(** [G_big_plus] turns list append into [+]. *)
Lemma big_plus_app : forall {G} `{Monoid G} (l1 l2 : list G),
  G_big_plus l1 + G_big_plus l2 = G_big_plus (l1 ++ l2).

(** The inverse of a list sum is the sum of the inverses, in reverse order. *)
Lemma big_plus_inv : forall {G} `{Group G} (l : list G),
  - (G_big_plus l) = G_big_plus (map Gopp (rev l)).

(** In the [nat] monoid (addition), [times_n k n] is ordinary multiplication. *)
Lemma times_n_nat : forall n k,
  times_n k n = (k * n)%nat.

(** In a commutative group, a sum of pointwise sums splits in two. *)
Lemma big_sum_plus : forall {G} `{Comm_Group G} f g n,
    big_sum (fun x => f x + g x) n = big_sum f n + big_sum g n.

(** Scalars distribute over [big_sum]. Note the naming: here [G] is the
    vector type and [V] the scalar field (per [Vector_Space V F]). *)
Lemma big_sum_scale_l : forall {G} {V} `{Vector_Space G V} c f n,
    c ⋅ big_sum f n = big_sum (fun x => c ⋅ f x) n.

(** Left multiplication distributes over [big_sum]. *)
Lemma big_sum_mult_l : forall {R} `{Ring R} c f n,
    c × big_sum f n = big_sum (fun x => c × f x) n.

(** Right multiplication distributes over [big_sum]. *)
Lemma big_sum_mult_r : forall {R} `{Ring R} c f n,
    big_sum f n × c = big_sum (fun x => f x × c) n.

(** An additive map [g] between groups commutes with [big_sum]. *)
Lemma big_sum_func_distr : forall {G1 G2} `{Group G1} `{Group G2} f (g : G1 -> G2) n,
    (forall a b, g (a + b) = g a + g b) -> g (big_sum f n) = big_sum (fun x => g (f x)) n.

(** A predicate that holds of [0], is closed under [+], and holds of every
    summand [f i] with [i < n] also holds of the whole sum. *)
Lemma big_sum_prop_distr : forall {G} `{Monoid G} f (p : G -> Prop) n,
    (forall a b, p a -> p b -> p (a + b)) -> p 0 -> (forall i, i < n -> p (f i)) ->
    p (big_sum f n).

(** Appending one term on the right: [big_sum]'s own recursion step. *)
Lemma big_sum_extend_r : forall {G} `{Monoid G} n f,
    big_sum f n + f n = big_sum f (S n).

(** Prepending the term [f 0] on the left (the converse of [big_sum_shift]). *)
Lemma big_sum_extend_l : forall {G} `{Monoid G} n f,
    f O + big_sum (fun x => f (S x)) n = big_sum f (S n).

(** If some index [x < n] has [f x = k] and every other index below [n]
    contributes [0], the sum is [k]. *)
Lemma big_sum_unique : forall {G} `{Monoid G} k (f : nat -> G) n,
  (exists x, (x < n)%nat /\ f x = k /\ (forall x', x' < n -> x <> x' -> f x' = 0)) ->
  big_sum f n = k.

(** Splits a sum of length [m + n] at position [m]. *)
Lemma big_sum_sum : forall {G} `{Monoid G} m n f,
  big_sum f (m + n) = big_sum f m + big_sum (fun x => f (m + x)%nat) n.

(** Splits a sum of length [2 * n] into two halves of length [n]. *)
Lemma big_sum_twice : forall {G} `{Monoid G} n f,
  big_sum f (2 * n) = big_sum f n + big_sum (fun x => f (n + x)%nat) n.

(** The product of two sums is a single sum over [m * n] indices, where
    index [x] is decoded as the pair [(x / n, x mod n)]; requires [n <> 0]. *)
Lemma big_sum_product : forall {G} `{Ring G} m n f g,
  n <> O ->
  big_sum f m × big_sum g n = big_sum (fun x => f (x / n)%nat × g (x mod n)%nat) (m * n).

(* Put [nat_scope] on top so that [/] and [*] applied to the index variables
   below read as [nat] operations rather than the [group_scope] notations. *)
Local Open Scope nat_scope.

(** A double sum over [x < m], [y < n] equals a single sum over [n * m]
    indices, with index [z] decoded as [(z / n, z mod n)]. *)
Lemma big_sum_double_sum : forall {G} `{Monoid G} (f : nat -> nat -> G) (n m : nat),
    big_sum (fun x => (big_sum (fun y => f x y) n)) m = big_sum (fun z => f (z / n) (z mod n)) (n * m).

Local Close Scope nat_scope.

(** Peels the last row ([i = n]), the last column ([j = m]) and the corner
    term [f n m] off a double sum of size [(S n) x (S m)]. *)
Lemma big_sum_extend_double : forall {G} `{Ring G} (f : nat -> nat -> G) (n m : nat),
  big_sum (fun i => big_sum (fun j => f i j) (S m)) (S n) =
  (big_sum (fun i => big_sum (fun j => f i j) m) n) + (big_sum (fun j => f n j) m) +
                      (big_sum (fun i => f i m) n) + f n m.

(** Splits a sum over [2 ^ (m + n)] indices into [2 ^ m] consecutive blocks
    of [2 ^ n] terms each. *)
Lemma nested_big_sum : forall {G} `{Monoid G} m n f,
  big_sum f (2 ^ (m + n))
    = big_sum (fun x => big_sum (fun y => f (x * 2 ^ n + y)%nat) (2 ^ n)) (2 ^ m).

(** In a commutative group, the two summations of a double sum can be swapped. *)
Lemma big_sum_swap_order : forall {G} `{Comm_Group G} (f : nat -> nat -> G) m n,
  big_sum (fun j => big_sum (fun i => f j i) m) n =
    big_sum (fun i => big_sum (fun j => f j i) n) m.

(** If [f i j] vanishes off the diagonal (for [i, j < n]), the double sum
    reduces to the sum of the diagonal terms. *)
Lemma big_sum_diagonal : forall {G} `{Monoid G} (f : nat -> nat -> G) n,
    (forall i j, (i < n)%nat -> (j < n)%nat -> (i <> j)%nat -> f i j = 0) ->
    big_sum (fun i => big_sum (fun j => f i j) n) n = big_sum (fun i => f i i) n.

(** For [nat]-valued sums, a pointwise [<=] below [n] lifts to the sums. *)
Lemma Nsum_le : forall n f g,
  (forall x, x < n -> f x <= g x)%nat ->
  (big_sum f n <= big_sum g n)%nat.




(** Syntax of monoid expressions over atoms of type [G], used for proof by
    reflection: [Ident] is the unit, [Var] an atom, [Op] the operation. *)
Inductive mexp {G} : Type :=
| Ident : mexp
| Var : G -> mexp
| Op : mexp -> mexp -> mexp.

(** Interprets a monoid expression back into [G]. *)
Fixpoint mdenote {G} `{Monoid G} (me : mexp) : G :=
  match me with
  | Ident => 0
  | Var v => v
  | Op me1 me2 => mdenote me1 + mdenote me2
  end.


(** Flattens an expression into its atoms in left-to-right order, dropping
    units and parenthesization. *)
Fixpoint flatten {G} `{Monoid G} (me : mexp) : list G :=
  match me with
  | Ident => nil
  | Var x => x :: nil
  | Op me1 me2 => flatten me1 ++ flatten me2
  end.

(** Flattening preserves the denotation. *)
Theorem flatten_correct : forall {G} `{Monoid G} me,
  mdenote me = G_big_plus (flatten me).

(** Expressions whose flattened sums agree denote equal values; this is the
    reflection step applied by [solve_monoid]. *)
Theorem monoid_reflect : forall {G} `{Monoid G} me1 me2,
  G_big_plus (flatten me1) = G_big_plus (flatten me2) ->
  mdenote me1 = mdenote me2.

(** Reifies a term built from [0] and [+] into an [mexp]; any other subterm
    becomes an opaque atom ([Var]). *)
Ltac reify_mon me :=
  match me with
  | 0 => Ident
  | ?me1 + ?me2 =>
      let r1 := reify_mon me1 in
      let r2 := reify_mon me2 in
      constr:(Op r1 r2)
  | _ => constr:(Var me)
  end.

(** Proves a monoid equation by reflection: reify both sides, reduce to
    equality of flattened sums via [monoid_reflect], then rewrite away units,
    reassociate, and try [easy]. *)
Ltac solve_monoid :=
  match goal with
  | [ |- ?me1 = ?me2 ] =>
      let r1 := reify_mon me1 in
      let r2 := reify_mon me2 in
      (* [mdenote r1] computes back to [me1], so this [change] is a conversion *)
      change (mdenote r1 = mdenote r2);
      apply monoid_reflect; simpl;
      repeat (rewrite Gplus_0_l);
      repeat (rewrite Gplus_0_r);
      repeat (rewrite Gplus_assoc); try easy
  end.


(** Sanity check for reassociation; presumably proved with [solve_monoid]
    (the proof is not shown here). *)
Lemma test : forall {G} `{Monoid G} a b c d, a + b + c + d = a + (b + c) + d.

(** Syntax of group expressions for proof by reflection: the monoid syntax
    plus an inverse node [Gmin]. *)
Inductive gexp {G} : Type :=
| Gident : gexp
| Gvar : G -> gexp
| Gop : gexp -> gexp -> gexp
| Gmin : gexp -> gexp.

(** Interprets a group expression back into [G]. *)
Fixpoint gdenote {G} `{Group G} (ge : gexp) : G :=
  match ge with
  | Gident => 0
  | Gvar v => v
  | Gop me1 me2 => gdenote me1 + gdenote me2
  | Gmin v => - gdenote v
  end.

(** Flattens a group expression into a list of atoms and inverted atoms.
    An inverse [Gmin ge'] becomes the reversed list of inverses, mirroring
    [-(g + h) = -h + -g]. *)
Fixpoint gflatten {G} `{Group G} (ge : gexp) : list G :=
  match ge with
  | Gident => nil
  | Gvar x => x :: nil
  | Gop ge1 ge2 => gflatten ge1 ++ gflatten ge2
  | Gmin ge' => map Gopp (rev (gflatten ge'))
  end.

(** Flattening preserves the denotation. *)
Theorem gflatten_correct : forall {G} `{Group G} ge,
    gdenote ge = G_big_plus (gflatten ge).

(** Expressions whose flattened sums agree denote equal values; this is the
    reflection step applied by [solve_group]. *)
Theorem group_reflect : forall {G} `{Group G} ge1 ge2,
  G_big_plus (gflatten ge1) = G_big_plus (gflatten ge2) ->
  gdenote ge1 = gdenote ge2.

(** Unfolds one step of [G_big_plus]. *)
Lemma big_plus_reduce : forall {G} `{Group G} a l,
  G_big_plus (a :: l) = a + G_big_plus l.

(** An element followed by its inverse cancels at the head of a list sum. *)
Lemma big_plus_inv_r : forall {G} `{Group G} a l,
  G_big_plus (a :: -a :: l) = G_big_plus l.

(** An inverse followed by its element cancels at the head of a list sum. *)
Lemma big_plus_inv_l : forall {G} `{Group G} a l,
  G_big_plus (-a :: a :: l) = G_big_plus l.

(** Reifies a group-valued term into a [gexp]. It recognizes [0], binary
    [+], binary [-] (encoded as [+] of an inverse) and unary [-]; any other
    subterm becomes an atom ([Gvar]). *)
Ltac reify_grp ge :=
  match ge with
  | 0 => Gident
  | ?ge1 + ?ge2 =>
      let r1 := reify_grp ge1 in
      let r2 := reify_grp ge2 in
      constr:(Gop r1 r2)
  | ?ge1 - ?ge2 =>
      let r1 := reify_grp ge1 in
      let r2 := reify_grp ge2 in
      constr:(Gop r1 (Gmin r2))
  | -?ge =>
      let r := reify_grp ge in
      constr:(Gmin r)
  | _ => constr:(Gvar ge)
  end.

(** Proves a group equation by reflection. It unfolds [Gminus], reifies both
    sides, and reduces to flattened sums via [group_reflect]. It then cancels
    adjacent inverse pairs, removes units, reassociates, and tries [easy]. *)
Ltac solve_group :=
  (* NOTE(review): after this unfold, binary [-] presumably no longer occurs
     in the goal, so only the [+] / unary [-] branches of [reify_grp] fire *)
  unfold Gminus;
  match goal with
  | [ |- ?me1 = ?me2 ] =>
      let r1 := reify_grp me1 in
      let r2 := reify_grp me2 in
      change (gdenote r1 = gdenote r2);
      (* simplify only [gflatten], keeping the goal as [G_big_plus] of a list
         so the [big_plus_*] cancellation lemmas can match it *)
      apply group_reflect; simpl gflatten;
      repeat (rewrite Gopp_involutive);
      repeat (try (rewrite big_plus_inv_r);
              try (rewrite big_plus_inv_l);
              try rewrite big_plus_reduce); simpl;
      repeat (rewrite Gplus_0_l); repeat (rewrite Gplus_0_r);
      repeat (rewrite Gplus_assoc); try easy
  end.

(** Sanity check for cancellation of [d - d]; presumably proved with
    [solve_group] (the proof is not shown here). *)
Lemma test2 : forall {G} `{Group G} a b c d, a + b + c + d - d = a + (b + c) + d - d.

(** Sanity check for inverting a sum; presumably proved with [solve_group]
    (the proof is not shown here). *)
Lemma test3 : forall {G} `{Group G} a b c, - (a + b + c) + a = 0 - c - b.