import numpy as np
import pandas as pd
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.nn as nn
import jax.numpy as jnp
import jax.random as random
# Load the input table. Later cells read its vote-count columns,
# `pvote_Mastriano` and `warddiv` from this frame.
# NOTE(review): absolute, machine-specific path — this only runs where the file exists.
df = pd.read_csv("/Users/jtannen/Documents/sixtysix/div_gov_sen.csv")
def expit(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``, elementwise.

    The formula is delegated to JAX's own ``expit`` instead of being spelled
    out by hand. In the hand-written form ``exp(-x)`` overflows to ``inf`` for
    large negative ``x``; the value still comes out as 0, but the autodiff
    gradient becomes ``0 * inf = NaN``, which can poison a gradient-based
    sampler such as NUTS. The library routine defines its derivative as
    ``y * (1 - y)``, which stays finite everywhere.

    Args:
        x: scalar or array of logits.

    Returns:
        Array of the same shape as ``x`` with values in [0, 1].
    """
    # Function-scope import, aliased so it does not shadow this wrapper's name.
    from jax.scipy.special import expit as _jax_expit

    return _jax_expit(x)
def model(votes_senate, votes_gov, pvote_mastriano):
    """NumPyro model relating governor vote counts to senate vote counts.

    For each (d, s, g) cell a probability is formed as
    ``expit(alpha[s, g] + beta[s, g] * pvote_mastriano[d])``. The senate
    counts are then given a Normal likelihood whose mean and variance are the
    ``votes_gov``-weighted sums over ``g`` of ``p`` and ``p * (1 - p)``.

    Args:
        votes_senate: observed senate counts, indexed (d, s).
        votes_gov: governor counts, indexed (d, g).
        pvote_mastriano: one covariate value per d; reshaped to (d, 1, 1)
            so it broadcasts against the 3x3 coefficient matrices.

    Sites registered:
        beta, alpha: 3x3 coefficients with standard Normal priors.
        probs: the cell probabilities rescaled to sum to one over the s axis.
        obs: Normal likelihood on ``votes_senate``.
        loglik: summed log-density of ``votes_senate`` under that Normal.
    """
    # Same prior for both coefficient matrices; beta is drawn first, then alpha.
    coef_prior = dist.Normal(jnp.zeros([3, 3]), 1)
    beta = numpyro.sample("beta", coef_prior)
    alpha = numpyro.sample("alpha", coef_prior)

    # Axes of the cell probabilities: N_div x N_sen x N_gov.
    covariate = pvote_mastriano.reshape(-1, 1, 1)
    probs_raw = expit(alpha[np.newaxis, :, :] + beta[np.newaxis, :, :] * covariate)

    # Recorded for post-processing only; normalised over the senate axis.
    numpyro.deterministic("probs", probs_raw / probs_raw.sum(axis=1, keepdims=True))

    # NOTE(review): the likelihood below is built from the unnormalised
    # probs_raw, not from the normalised "probs" site — confirm this is intended.
    means = jnp.einsum("dsg,dg->ds", probs_raw, votes_gov)
    variances = jnp.einsum("dsg,dg->ds", probs_raw * (1 - probs_raw), votes_gov)
    likelihood = dist.Normal(means, jnp.sqrt(variances))

    numpyro.sample("obs", likelihood, obs=votes_senate)
    numpyro.deterministic("loglik", likelihood.log_prob(votes_senate).sum())
# Seed NumPy's global RNG; the sampler itself is keyed by the JAX PRNGKey below.
np.random.seed(215)

# Vote-count matrices: one row per row of df, one column per candidate.
gov_columns = ["votes_Mastriano", "votes_Shapiro", "votes_GOVERNOR - Other Candidate"]
senate_columns = ["votes_Oz", "votes_Fetterman", "votes_UNITED STATES SENATOR - Other Candidate"]
votes_gov = df[gov_columns].to_numpy()
votes_senate = df[senate_columns].to_numpy()

# Fit with NUTS: 500 warm-up iterations followed by 1000 retained draws.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
model_inputs = {
    "votes_senate": votes_senate,
    "votes_gov": votes_gov,
    "pvote_mastriano": df["pvote_Mastriano"].to_numpy(),
}
mcmc.run(rng_key=random.PRNGKey(215), **model_inputs)
sample: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1500/1500 [00:16<00:00, 90.80it/s, 63 steps of size 8.54e-02. acc. prob=0.93]
# Print the per-parameter posterior summary (mean, std, quantiles, n_eff, r_hat).
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat alpha[0,0] 2.42 0.08 2.41 2.28 2.56 913.05 1.00 alpha[0,1] -4.30 0.03 -4.30 -4.34 -4.26 704.10 1.00 alpha[0,2] -2.77 0.39 -2.72 -3.35 -2.16 948.58 1.00 alpha[1,0] -1.45 0.08 -1.45 -1.59 -1.33 758.07 1.00 alpha[1,1] 3.67 0.02 3.67 3.63 3.71 750.10 1.00 alpha[1,2] 1.17 0.17 1.16 0.91 1.45 1010.79 1.00 alpha[2,0] -3.28 0.26 -3.26 -3.68 -2.83 597.04 1.00 alpha[2,1] -4.74 0.04 -4.75 -4.80 -4.68 910.47 1.00 alpha[2,2] -0.44 0.10 -0.44 -0.60 -0.27 734.86 1.00 beta[0,0] -3.65 0.18 -3.65 -3.92 -3.35 929.15 1.00 beta[0,1] 8.90 0.12 8.90 8.68 9.08 1223.30 1.00 beta[0,2] 1.95 1.07 1.95 0.30 3.79 1191.00 1.00 beta[1,0] 2.00 0.15 2.00 1.77 2.26 873.53 1.00 beta[1,1] -8.24 0.10 -8.24 -8.41 -8.07 1229.53 1.00 beta[1,2] -3.59 0.82 -3.60 -4.93 -2.24 1052.71 1.00 beta[2,0] -0.33 0.32 -0.34 -0.81 0.24 741.25 1.00 beta[2,1] -0.08 1.11 0.00 -1.87 1.69 669.75 1.00 beta[2,2] 0.63 0.54 0.64 -0.15 1.59 816.44 1.00 Number of divergences: 0
# Collect the posterior draws; the result is indexed by site name below.
posterior_samples = mcmc.get_samples()
# Peek at the first ten draws of the "loglik" deterministic site.
posterior_samples['loglik'][:10]
DeviceArray([-14386.096, -14382.528, -14388.564, -14380.695, -14387.931, -14394.938, -14383.229, -14390.517, -14390.294, -14391.021], dtype=float32)
from matplotlib import pyplot as plt
# Trace of the summed log-likelihood across posterior draws.
plt.plot(posterior_samples['loglik'])
[<matplotlib.lines.Line2D at 0x296a916a0>]
# Posterior mean of the "probs" site, averaging over the leading (draw) axis.
div_probs = posterior_samples["probs"].mean(axis=0)
div_probs
DeviceArray([[[0.7621246 , 0.04520724, 0.07020296], [0.20688044, 0.94599426, 0.5721061 ], [0.03099496, 0.0087985 , 0.357691 ]], [[0.76991457, 0.03632786, 0.06616237], [0.19885421, 0.9549252 , 0.5832254 ], [0.03123127, 0.00874721, 0.35061216]], [[0.7515514 , 0.05985679, 0.07585191], [0.2177366 , 0.93126076, 0.5568612 ], [0.0307121 , 0.00888242, 0.36728695]], ..., [[0.6305066 , 0.46803164, 0.14697033], [0.34001532, 0.5219423 , 0.3883959 ], [0.02947815, 0.01002602, 0.46463376]], [[0.62784714, 0.47937468, 0.14853886], [0.34267607, 0.5105869 , 0.38509282], [0.02947686, 0.01003836, 0.46636835]], [[0.64273256, 0.4151022 , 0.13971844], [0.32777435, 0.57493794, 0.4038754 ], [0.02949318, 0.00995998, 0.45640627]]], dtype=float32)
# Multiply each (d, s, g) probability by the governor count for the same (d, g);
# no index is summed out, so the result keeps all three axes.
div_crosstabs = jnp.einsum("dsg,dg->dsg", div_probs, votes_gov)
# Summing over the last (g) axis leaves one simulated total per (d, s).
div_crosstabs.sum(axis=2)
DeviceArray([[ 50.19893 , 291.287 , 5.51408 ], [ 49.644142, 359.64352 , 5.712438], [ 86.56503 , 391.87265 , 7.562337], ..., [197.6283 , 160.56392 , 7.807793], [136.51959 , 108.94787 , 7.532552], [167.37955 , 153.43486 , 9.18563 ]], dtype=float32)
# Display the observed senate counts (presumably for comparison with the simulated totals).
votes_senate
array([[ 56, 290, 4], [ 57, 357, 3], [ 97, 383, 5], ..., [192, 160, 11], [138, 106, 11], [176, 147, 14]])
# Simulated senate shares: per-(d, s) totals divided by their row sum.
pvote_sim = div_crosstabs.sum(axis=2) / div_crosstabs.sum(axis=2).sum(axis=1, keepdims=True)
# Observed senate shares, normalised the same way.
pvote_obs = votes_senate / votes_senate.sum(axis=1, keepdims=True)
# Observed-vs-simulated scatter for each of the three senate columns, plus a y = x reference line.
# NOTE(review): indentation was lost in this export, so it is unclear whether the
# plot/show calls sit inside the loop (one figure per column) or after it — confirm.
for j in range(3):
plt.scatter(pvote_obs[:,j], pvote_sim[:,j])
plt.plot([0,1], [0,1], '-y')
plt.show()
# Display the full simulated crosstab array (axes d, s, g).
div_crosstabs
DeviceArray([[[3.6581982e+01, 1.3336135e+01, 2.8081185e-01], [9.9302607e+00, 2.7906830e+02, 2.2884245e+00], [1.4877583e+00, 2.5955575e+00, 1.4307640e+00]], [[3.6185986e+01, 1.3259668e+01, 1.9848710e-01], [9.3461475e+00, 3.4854770e+02, 1.7496762e+00], [1.4678699e+00, 3.1927319e+00, 1.0518365e+00]], [[6.2378765e+01, 2.3882858e+01, 3.0340764e-01], [1.8072138e+01, 3.7157306e+02, 2.2274449e+00], [2.5491045e+00, 3.5440848e+00, 1.4691478e+00]], ..., [[1.0466409e+02, 9.2670265e+01, 2.9394066e-01], [5.6442543e+01, 1.0334458e+02, 7.7679181e-01], [4.8933730e+00, 1.9851528e+00, 9.2926753e-01]], [[7.2830269e+01, 6.2798084e+01, 8.9123315e-01], [3.9750423e+01, 6.6886887e+01, 2.3105569e+00], [3.4193163e+00, 1.3150253e+00, 2.7982101e+00]], [[9.1268021e+01, 7.5133499e+01, 9.7802913e-01], [4.6543957e+01, 1.0406377e+02, 2.8271279e+00], [4.1880307e+00, 1.8027555e+00, 3.1948438e+00]]], dtype=float32)
from itertools import product

# Peek at the flattening order: reshape(-1) walks the (d, s, g) axes with the
# last axis varying fastest, the same order product() generates below.
div_crosstabs.reshape(-1)[:9]

# Labels for the senate and governor axes, in column order.
sen_labels = np.array(["Oz", "Fetterman", "Other"])
gov_labels = np.array(["Mastriano", "Shapiro", "Other"])

# Take the number of rows from the data instead of hard-coding 1703, so the
# long table always lines up with the flattened `sim` values.
n_div = div_crosstabs.shape[0]

# Long-format table: one row per (warddiv, sen, gov) cell with its simulated count.
res = pd.DataFrame(
    product(range(n_div), range(len(sen_labels)), range(len(gov_labels))),
    columns=("warddiv", "sen", "gov"),
).assign(
    # Each lambda sees the still-integer index column it replaces.
    warddiv=lambda res: df.warddiv.iloc[res.warddiv].values,
    sen=lambda res: sen_labels[res.sen],
    gov=lambda res: gov_labels[res.gov],
    sim=lambda res: div_crosstabs.reshape(-1),
)
display(res)
res.to_csv("/Users/jtannen/Documents/sen_gov_sim.csv")
 | warddiv | sen | gov | sim |
---|---|---|---|---|
0 | 01-01 | Oz | Mastriano | 36.581982 |
1 | 01-01 | Oz | Shapiro | 13.336135 |
2 | 01-01 | Oz | Other | 0.280812 |
3 | 01-01 | Fetterman | Mastriano | 9.930261 |
4 | 01-01 | Fetterman | Shapiro | 279.068298 |
... | ... | ... | ... | ... |
15322 | 66-46 | Fetterman | Shapiro | 104.063766 |
15323 | 66-46 | Fetterman | Other | 2.827128 |
15324 | 66-46 | Other | Mastriano | 4.188031 |
15325 | 66-46 | Other | Shapiro | 1.802755 |
15326 | 66-46 | Other | Other | 3.194844 |
15327 rows × 4 columns