In [1]:
import numpy as np
import pandas as pd
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import jax.nn as nn
import jax.numpy as jnp
import jax.random as random
In [3]:
# Load the results table; later cells read its vote-count columns, `pvote_Mastriano`, and `warddiv`.
# NOTE(review): absolute, machine-specific path — the notebook will not run elsewhere as written.
df = pd.read_csv("/Users/jtannen/Documents/sixtysix/div_gov_sen.csv")
In [6]:
def expit(x):
    """Logistic sigmoid, 1 / (1 + exp(-x)), applied elementwise.

    Delegates to ``jax.nn.sigmoid`` instead of spelling the formula out:
    the literal form overflows ``exp(-x)`` for very negative ``x`` in
    float32, which produces NaN gradients (inf * 0) under autodiff even
    though the forward value is still 0.

    Args:
        x: scalar or array of logits.

    Returns:
        Array of the same shape as ``x`` with values in [0, 1].
    """
    return nn.sigmoid(x)
In [143]:
def model(votes_senate, votes_gov, pvote_mastriano):
    """NumPyro model of senate vote counts given governor vote counts.

    For every (senate candidate, governor candidate) pair, a logistic curve in
    ``pvote_mastriano`` gives a rate. Each division's senate counts are then
    modelled as Normal around the rate-weighted governor counts, with a
    p * (1 - p) variance per governor vote.

    Args:
        votes_senate: observed senate counts, indexed (division, senate candidate).
        votes_gov: governor counts, indexed (division, governor candidate).
        pvote_mastriano: one value per division (reshaped to (-1, 1, 1) below).

    Sample sites:
        "beta", "alpha": Normal(0, 1) slope and intercept per (sen, gov) pair.
        "probs": rates renormalised to sum to 1 over the senate axis.
        "obs": the Normal likelihood, observed at ``votes_senate``.
        "loglik": summed log-density of ``votes_senate`` under that likelihood.
    """
    # Take the category counts from the inputs rather than hard-coding 3 x 3.
    n_sen, n_gov = votes_senate.shape[1], votes_gov.shape[1]

    # Keep the beta-then-alpha order so the RNG stream matches earlier runs.
    beta = numpyro.sample("beta", dist.Normal(jnp.zeros([n_sen, n_gov]), 1))
    alpha = numpyro.sample("alpha", dist.Normal(jnp.zeros([n_sen, n_gov]), 1))

    # N_div * N_sen * N_gov
    probs_raw = expit(
        alpha[np.newaxis, :, :] + beta[np.newaxis, :, :] * pvote_mastriano.reshape(-1, 1, 1)
    )
    numpyro.deterministic("probs", probs_raw / probs_raw.sum(axis=1, keepdims=True))

    # NOTE(review): the likelihood is built from the un-normalised rates, while
    # the reported "probs" site is renormalised over senate candidates —
    # confirm this difference is intended.
    means = jnp.einsum("dsg,dg->ds", probs_raw, votes_gov)
    variances = jnp.einsum("dsg,dg->ds", probs_raw * (1 - probs_raw), votes_gov)

    # Build the likelihood once and reuse it for both the observed site and
    # the recorded log-likelihood.
    likelihood = dist.Normal(means, jnp.sqrt(variances))
    numpyro.sample("obs", likelihood, obs=votes_senate)
    numpyro.deterministic("loglik", likelihood.log_prob(votes_senate).sum())
In [154]:
# Seed NumPy's global RNG; the sampler's own randomness comes from the
# explicit PRNGKey handed to `mcmc.run` below.
np.random.seed(215)

# Vote-count matrices, one row per row of `df`, columns in the order listed.
gov_columns = ["votes_Mastriano", "votes_Shapiro", "votes_GOVERNOR - Other Candidate"]
senate_columns = ["votes_Oz", "votes_Fetterman", "votes_UNITED STATES SENATOR - Other Candidate"]
votes_gov = df[gov_columns].to_numpy()
votes_senate = df[senate_columns].to_numpy()

# Run the model: NUTS with 500 warmup iterations followed by 1000 kept draws.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(
    rng_key=random.PRNGKey(215),
    votes_senate=votes_senate,
    votes_gov=votes_gov,
    pvote_mastriano=df["pvote_Mastriano"].to_numpy(),
)
sample: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1500/1500 [00:16<00:00, 90.80it/s, 63 steps of size 8.54e-02. acc. prob=0.93]
In [156]:
# Posterior summary table for the sampled sites, including the n_eff and r_hat diagnostics.
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
alpha[0,0]      2.42      0.08      2.41      2.28      2.56    913.05      1.00
alpha[0,1]     -4.30      0.03     -4.30     -4.34     -4.26    704.10      1.00
alpha[0,2]     -2.77      0.39     -2.72     -3.35     -2.16    948.58      1.00
alpha[1,0]     -1.45      0.08     -1.45     -1.59     -1.33    758.07      1.00
alpha[1,1]      3.67      0.02      3.67      3.63      3.71    750.10      1.00
alpha[1,2]      1.17      0.17      1.16      0.91      1.45   1010.79      1.00
alpha[2,0]     -3.28      0.26     -3.26     -3.68     -2.83    597.04      1.00
alpha[2,1]     -4.74      0.04     -4.75     -4.80     -4.68    910.47      1.00
alpha[2,2]     -0.44      0.10     -0.44     -0.60     -0.27    734.86      1.00
 beta[0,0]     -3.65      0.18     -3.65     -3.92     -3.35    929.15      1.00
 beta[0,1]      8.90      0.12      8.90      8.68      9.08   1223.30      1.00
 beta[0,2]      1.95      1.07      1.95      0.30      3.79   1191.00      1.00
 beta[1,0]      2.00      0.15      2.00      1.77      2.26    873.53      1.00
 beta[1,1]     -8.24      0.10     -8.24     -8.41     -8.07   1229.53      1.00
 beta[1,2]     -3.59      0.82     -3.60     -4.93     -2.24   1052.71      1.00
 beta[2,0]     -0.33      0.32     -0.34     -0.81      0.24    741.25      1.00
 beta[2,1]     -0.08      1.11      0.00     -1.87      1.69    669.75      1.00
 beta[2,2]      0.63      0.54      0.64     -0.15      1.59    816.44      1.00

Number of divergences: 0
In [157]:
# Collect the posterior draws (indexed by site name) and peek at the first
# ten recorded log-likelihood values.
posterior_samples = mcmc.get_samples()
posterior_samples['loglik'][:10]
Out[157]:
DeviceArray([-14386.096, -14382.528, -14388.564, -14380.695, -14387.931,
             -14394.938, -14383.229, -14390.517, -14390.294, -14391.021],            dtype=float32)
In [158]:
from matplotlib import pyplot as plt

# Trace of the summed log-likelihood across posterior draws. Labelled so the
# figure can be read on its own; the trailing `;` hides the return-value repr.
fig, ax = plt.subplots()
ax.plot(posterior_samples['loglik'])
ax.set(xlabel="posterior draw", ylabel="log-likelihood", title="Log-likelihood trace");
Out[158]:
[<matplotlib.lines.Line2D at 0x296a916a0>]
In [159]:
# Posterior-mean rates: average the "probs" site over the draw axis (axis 0),
# leaving one (senate, governor) table per division.
div_probs = jnp.mean(posterior_samples["probs"], axis=0)
div_probs
Out[159]:
DeviceArray([[[0.7621246 , 0.04520724, 0.07020296],
              [0.20688044, 0.94599426, 0.5721061 ],
              [0.03099496, 0.0087985 , 0.357691  ]],

             [[0.76991457, 0.03632786, 0.06616237],
              [0.19885421, 0.9549252 , 0.5832254 ],
              [0.03123127, 0.00874721, 0.35061216]],

             [[0.7515514 , 0.05985679, 0.07585191],
              [0.2177366 , 0.93126076, 0.5568612 ],
              [0.0307121 , 0.00888242, 0.36728695]],

             ...,

             [[0.6305066 , 0.46803164, 0.14697033],
              [0.34001532, 0.5219423 , 0.3883959 ],
              [0.02947815, 0.01002602, 0.46463376]],

             [[0.62784714, 0.47937468, 0.14853886],
              [0.34267607, 0.5105869 , 0.38509282],
              [0.02947686, 0.01003836, 0.46636835]],

             [[0.64273256, 0.4151022 , 0.13971844],
              [0.32777435, 0.57493794, 0.4038754 ],
              [0.02949318, 0.00995998, 0.45640627]]], dtype=float32)
In [178]:
# Scale each governor column of the rate table by that candidate's vote count
# in the division, giving simulated (division, senate, governor) cell counts.
div_crosstabs = div_probs * votes_gov[:, jnp.newaxis, :]
In [164]:
# Sum the simulated crosstab over the governor axis to get the implied senate
# totals per division (compare with the observed `votes_senate`).
div_crosstabs.sum(axis=2)
Out[164]:
DeviceArray([[ 50.19893 , 291.287   ,   5.51408 ],
             [ 49.644142, 359.64352 ,   5.712438],
             [ 86.56503 , 391.87265 ,   7.562337],
             ...,
             [197.6283  , 160.56392 ,   7.807793],
             [136.51959 , 108.94787 ,   7.532552],
             [167.37955 , 153.43486 ,   9.18563 ]], dtype=float32)
In [165]:
# Observed senate counts, shown for comparison with the simulated totals.
votes_senate
Out[165]:
array([[ 56, 290,   4],
       [ 57, 357,   3],
       [ 97, 383,   5],
       ...,
       [192, 160,  11],
       [138, 106,  11],
       [176, 147,  14]])
In [177]:
# Compare simulated and observed senate vote shares, one scatter per candidate.
sim_senate_totals = div_crosstabs.sum(axis=2)
pvote_sim = sim_senate_totals / sim_senate_totals.sum(axis=1, keepdims=True)
pvote_obs = votes_senate / votes_senate.sum(axis=1, keepdims=True)

# Labels follow the column order used to build `votes_senate`.
for j, candidate in enumerate(["Oz", "Fetterman", "Other"]):
    fig, ax = plt.subplots()
    ax.scatter(pvote_obs[:, j], pvote_sim[:, j])
    # y = x reference line: points on it are reproduced exactly.
    ax.plot([0, 1], [0, 1], '-y')
    ax.set(
        xlabel="observed vote share",
        ylabel="simulated vote share",
        title=f"Senate: {candidate}",
    )
    plt.show()
In [179]:
# Full simulated crosstab, indexed (division, senate candidate, governor candidate).
div_crosstabs
Out[179]:
DeviceArray([[[3.6581982e+01, 1.3336135e+01, 2.8081185e-01],
              [9.9302607e+00, 2.7906830e+02, 2.2884245e+00],
              [1.4877583e+00, 2.5955575e+00, 1.4307640e+00]],

             [[3.6185986e+01, 1.3259668e+01, 1.9848710e-01],
              [9.3461475e+00, 3.4854770e+02, 1.7496762e+00],
              [1.4678699e+00, 3.1927319e+00, 1.0518365e+00]],

             [[6.2378765e+01, 2.3882858e+01, 3.0340764e-01],
              [1.8072138e+01, 3.7157306e+02, 2.2274449e+00],
              [2.5491045e+00, 3.5440848e+00, 1.4691478e+00]],

             ...,

             [[1.0466409e+02, 9.2670265e+01, 2.9394066e-01],
              [5.6442543e+01, 1.0334458e+02, 7.7679181e-01],
              [4.8933730e+00, 1.9851528e+00, 9.2926753e-01]],

             [[7.2830269e+01, 6.2798084e+01, 8.9123315e-01],
              [3.9750423e+01, 6.6886887e+01, 2.3105569e+00],
              [3.4193163e+00, 1.3150253e+00, 2.7982101e+00]],

             [[9.1268021e+01, 7.5133499e+01, 9.7802913e-01],
              [4.6543957e+01, 1.0406377e+02, 2.8271279e+00],
              [4.1880307e+00, 1.8027555e+00, 3.1948438e+00]]],            dtype=float32)
In [205]:
from itertools import product
In [232]:
# Flatten the (division, senate, governor) crosstab into a long table with one
# row per cell, label each axis, and write it to CSV.
# Axis sizes come from the array itself rather than a hard-coded division count.
n_div, n_sen, n_gov = div_crosstabs.shape
sen_labels = np.array(["Oz", "Fetterman", "Other"])
gov_labels = np.array(["Mastriano", "Shapiro", "Other"])

res = pd.DataFrame(
    # product() varies the last index fastest, matching the row-major reshape(-1) below.
    product(range(n_div), range(n_sen), range(n_gov)),
    columns=("warddiv", "sen", "gov")
).assign(
    warddiv = lambda res: df.warddiv.iloc[res.warddiv].values,
    sen = lambda res: sen_labels[res.sen],
    gov = lambda res: gov_labels[res.gov],
    sim = lambda res: np.asarray(div_crosstabs).reshape(-1),
)
display(res)
# NOTE(review): absolute, machine-specific output path.
res.to_csv("/Users/jtannen/Documents/sen_gov_sim.csv")
warddiv sen gov sim
0 01-01 Oz Mastriano 36.581982
1 01-01 Oz Shapiro 13.336135
2 01-01 Oz Other 0.280812
3 01-01 Fetterman Mastriano 9.930261
4 01-01 Fetterman Shapiro 279.068298
... ... ... ... ...
15322 66-46 Fetterman Shapiro 104.063766
15323 66-46 Fetterman Other 2.827128
15324 66-46 Other Mastriano 4.188031
15325 66-46 Other Shapiro 1.802755
15326 66-46 Other Other 3.194844

15327 rows × 4 columns