From dc5dc1da4cfa10647d13f393b6685f962b9b9faa Mon Sep 17 00:00:00 2001 From: Xavier Leroy Date: Sun, 7 Jul 2019 17:29:12 +0200 Subject: Add FMA (fused multiply-add) Cherry-pick of the following commit on upstream Flocq: https://gitlab.inria.fr/flocq/flocq/commit/28cc6ee3a278878f3df002aab64a6b93e9412d34 --- flocq/IEEE754/Binary.v | 121 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) (limited to 'flocq') diff --git a/flocq/IEEE754/Binary.v b/flocq/IEEE754/Binary.v index 0ec3a297..ac38c761 100644 --- a/flocq/IEEE754/Binary.v +++ b/flocq/IEEE754/Binary.v @@ -1839,6 +1839,127 @@ now rewrite <- cond_Zopp_negb. now destruct y as [ | | | ]. Qed. +(** Fused Multiply-Add *) + +Definition Bfma_szero m (x y z: binary_float) : bool := + let s_xy := xorb (Bsign x) (Bsign y) in (* sign of product x*y *) + if Bool.eqb s_xy (Bsign z) then s_xy + else match m with mode_DN => true | _ => false end. + +Definition Bfma fma_nan m (x y z: binary_float) := + match x, y with + | B754_nan _ _ _, _ | _, B754_nan _ _ _ + | B754_infinity _, B754_zero _ + | B754_zero _, B754_infinity _ => + (* Multiplication produces NaN *) + build_nan (fma_nan x y z) + | B754_infinity sx, B754_infinity sy + | B754_infinity sx, B754_finite sy _ _ _ + | B754_finite sx _ _ _, B754_infinity sy => + let s := xorb sx sy in + (* Multiplication produces infinity with sign [s] *) + match z with + | B754_nan _ _ _ => build_nan (fma_nan x y z) + | B754_infinity sz => + if Bool.eqb s sz then z else build_nan (fma_nan x y z) + | _ => B754_infinity s + end + | B754_finite sx _ _ _, B754_zero sy + | B754_zero sx, B754_finite sy _ _ _ + | B754_zero sx, B754_zero sy => + (* Multiplication produces zero *) + match z with + | B754_nan _ _ _ => build_nan (fma_nan x y z) + | B754_zero _ => B754_zero (Bfma_szero m x y z) + | _ => z + end + | B754_finite sx mx ex _, B754_finite sy my ey _ => + (* Multiplication produces a finite, non-zero result *) + match z with + | B754_nan _ _ _ => build_nan (fma_nan x y z) + | B754_infinity sz => z + | B754_zero _ => + let X := Float radix2 (cond_Zopp sx (Zpos mx)) ex in + let Y := Float radix2 (cond_Zopp sy (Zpos my)) ey in + let '(Float _ mr er) := Fmult X Y in + binary_normalize m mr er (Bfma_szero m x y z) + | B754_finite sz mz ez _ => + let X := Float radix2 (cond_Zopp sx (Zpos mx)) ex in + let Y := Float radix2 (cond_Zopp sy (Zpos my)) ey in + let Z := Float radix2 (cond_Zopp sz (Zpos mz)) ez in + let '(Float _ mr er) := Fplus (Fmult X Y) Z in + binary_normalize m mr er (Bfma_szero m x y z) + end + end. + +Theorem Bfma_correct: + forall fma_nan m x y z, + let res := (B2R x * B2R y + B2R z)%R in + is_finite x = true -> + is_finite y = true -> + is_finite z = true -> + if Rlt_bool (Rabs (round radix2 fexp (round_mode m) res)) (bpow radix2 emax) then + B2R (Bfma fma_nan m x y z) = round radix2 fexp (round_mode m) res /\ + is_finite (Bfma fma_nan m x y z) = true /\ + Bsign (Bfma fma_nan m x y z) = + match Rcompare res 0 with + | Eq => Bfma_szero m x y z + | Lt => true + | Gt => false + end + else + B2FF (Bfma fma_nan m x y z) = binary_overflow m (Rlt_bool res 0). +Proof. + intros. pattern (Bfma fma_nan m x y z). + match goal with |- ?p ?x => set (PROP := p) end. + set (szero := Bfma_szero m x y z). + assert (BINORM: forall mr er, F2R (Float radix2 mr er) = res -> + PROP (binary_normalize m mr er szero)). + { intros mr er E. + specialize (binary_normalize_correct m mr er szero). + change (FLT_exp (3 - emax - prec) prec) with fexp. rewrite E. tauto. + } + set (add_zero := + match z with + | B754_nan _ _ _ => build_nan (fma_nan x y z) + | B754_zero sz => B754_zero szero + | _ => z + end). + assert (ADDZERO: B2R x = 0%R \/ B2R y = 0%R -> PROP add_zero). + { + intros Z. + assert (RES: res = B2R z). + { unfold res. destruct Z as [E|E]; rewrite E, ?Rmult_0_l, ?Rmult_0_r, Rplus_0_l; auto. } + unfold PROP, add_zero; destruct z as [ sz | sz | sz plz | sz mz ez Bz]; try discriminate. + - simpl in RES; rewrite RES; rewrite round_0 by apply valid_rnd_round_mode. + rewrite Rlt_bool_true. split. reflexivity. split. reflexivity. + rewrite Rcompare_Eq by auto. reflexivity. + rewrite Rabs_R0; apply bpow_gt_0. + - rewrite RES, round_generic, Rlt_bool_true. + split. reflexivity. split. reflexivity. + unfold B2R. destruct sz. + rewrite Rcompare_Lt. auto. apply F2R_lt_0. reflexivity. + rewrite Rcompare_Gt. auto. apply F2R_gt_0. reflexivity. + apply abs_B2R_lt_emax. apply valid_rnd_round_mode. apply generic_format_B2R. + } + destruct x as [ sx | sx | sx plx | sx mx ex Bx]; + destruct y as [ sy | sy | sy ply | sy my ey By]; + try discriminate. +- apply ADDZERO; auto. +- apply ADDZERO; auto. +- apply ADDZERO; auto. +- destruct z as [ sz | sz | sz plz | sz mz ez Bz]; try discriminate; unfold Bfma. ++ set (X := Float radix2 (cond_Zopp sx (Zpos mx)) ex). + set (Y := Float radix2 (cond_Zopp sy (Zpos my)) ey). + destruct (Fmult X Y) as [mr er] eqn:FRES. + apply BINORM. unfold res. rewrite <- FRES, F2R_mult, Rplus_0_r. auto. ++ set (X := Float radix2 (cond_Zopp sx (Zpos mx)) ex). + set (Y := Float radix2 (cond_Zopp sy (Zpos my)) ey). + set (Z := Float radix2 (cond_Zopp sz (Zpos mz)) ez). + destruct (Fplus (Fmult X Y) Z) as [mr er] eqn:FRES. + apply BINORM. unfold res. rewrite <- FRES, F2R_plus, F2R_mult. auto. +Qed. + (** Division *) Definition Fdiv_core_binary m1 e1 m2 e2 := -- cgit