1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
using ll = long long;
// Size of the global arrays px, cx, py, cy and mi.
constexpr int N = 3000010;
// Modulus of the printed answer (998244353 is prime).
constexpr ll mod = 998244353;
// Exponent pair for one prime shared by x and y: `a` is its exponent in x,
// `b` its exponent in y. Members default to 0 so a default-constructed node
// never holds indeterminate values.
struct node {
    int a = 0;
    int b = 0;
    node() = default;
    node(int ta, int tb) : a(ta), b(tb) {}
};
// Input: i ranges over [a, b], j over [c, d]; x and y are the two bases.
int a, b, c, d, x, y;
// Factorisation of x as filled by divide(): primes px[1..mx], exponents cx[1..mx].
int px[N], cx[N];
// Factorisation of y: primes py[1..my], exponents cy[1..my].
int py[N], cy[N];
int mx, my;
// For each prime dividing both x and y: (exponent in x, exponent in y) ...
vector<node> v;
// ... and the prime itself, ascending; v[k] belongs to alls[k].
vector<int> alls;
// Per shared prime (1-based, parallel to alls): accumulated exponent, kept modulo (mod - 1).
ll mi[N];
// Trial-division factorisation of n. On return p[1..m] holds the distinct prime
// factors in ascending order and c[1..m] their multiplicities; m == 0 for n <= 1.
void divide(int n, int *p, int *c, int &m) {
    m = 0;
    // f <= n / f is the integer form of f * f <= n; n shrinks as factors are removed.
    for (int f = 2; f <= n / f; ++f) {
        if (n % f != 0) {
            continue;
        }
        ++m;
        p[m] = f;
        c[m] = 0;
        do {
            n /= f;
            ++c[m];
        } while (n % f == 0);
    }
    // Whatever remains above 1 has no factor <= its square root, so it is prime.
    if (n > 1) {
        ++m;
        p[m] = n;
        c[m] = 1;
    }
}
// Square-and-multiply: returns a^n modulo `mod`. n is assumed non-negative.
// a is not reduced first, so a * a is assumed to fit in ll.
ll power(ll a, ll n) {
    ll acc = 1;
    for (ll base = a, e = n; e != 0; e >>= 1) {
        if (e & 1) {
            acc = acc * base % mod;
        }
        base = base * base % mod;
    }
    return acc;
}
// 1-based position of the first element of `alls` that is not less than x;
// for a value present in `alls` this is its index + 1. lower_bound requires
// `alls` to be sorted ascending.
int get_id(int x) {
    auto first_not_less = lower_bound(alls.begin(), alls.end(), x);
    return static_cast<int>(first_not_less - alls.begin()) + 1;
}
int main() { scanf("%d%d%d%d%d%d", &a, &b, &c, &d, &x, &y); divide(x, px, cx, mx); divide(y, py, cy, my); int ppx = 1, ppy = 1; while (ppx <= mx && ppy <= my) { if (px[ppx] == py[ppy]) { v.push_back(node(cx[ppx], cy[ppy])); alls.push_back(px[ppx]); ppx += 1; ppy += 1; } else if (px[ppx] < py[ppy]) ppx += 1; else ppy += 1; } for (int i = a; i <= b; i++) { for (int k = 0; k < v.size(); k++) { int mia = v[k].a, mib = v[k].b, id = get_id(alls[k]); int j = (i * mia + mib - 1) / mib; if (j > d) { ll t = 1ll * (c + d) * (d - c + 1) / 2; mi[id] = (mi[id] + t * mib) % (mod - 1); } else if (j < c) { ll t = d - c + 1; mi[id] = (mi[id] + t * mia * i) % (mod - 1); } else { ll ta = 1ll * (c + j - 1) * (j - c) / 2, tb = d - j + 1; mi[id] = (mi[id] + ta * mib) % (mod - 1); mi[id] = (mi[id] + tb * mia * i) % (mod - 1); } } } ll res = 1; for (int i = 1; i <= alls.size(); i++) { if (0 == mi[i]) continue; ll t = power(alls[i - 1], mi[i]); res = res * t % mod; } printf("%lld\n", res); return 0; }
|