


In [1]:
import pymc3 as pm
import arviz as az
In [3]:
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as st
import seaborn as sns
In [28]:
# Ignore a common pymc3 warning that comes from library functions, not our code.
# Pymc3 may throw additional warnings, but other warnings should be manageable
# by following the instructions included within the warning messages.
import warnings

messages=[
    "Using `from_pymc3` without the model will be deprecated in a future release",
]

for m in messages:
    warnings.filterwarnings("ignore", message=m)
In [4]:
n_theta = 10000

# generate 10,000 values from Beta(2,5)
theta = np.random.beta(2,5,n_theta)
print("First five  values of theta:\n\t", theta[0:5])
print("Sample mean:\n\t", np.mean(theta))
print("The 2.5% and 97.5% of quantiles:\n\t", np.percentile(theta,[2.5,97.5]))
First five  values of theta:
	 [0.47610973 0.17998453 0.11623631 0.4425017  0.1875366 ]
Sample mean:
	 0.28306686085461086
The 2.5% and 97.5% quantiles:
	 [0.04394589 0.63105426]
In [5]:
plt.hist(theta,50)
plt.xlabel("Value of Theta")
plt.ylabel("Count")
plt.show()
In [6]:
# simulate y from posterior predictive distribution
y = np.random.binomial(1, theta, n_theta) # generate a heads/tails value from each of the 10,000 thetas

print("First 5 heads/tails values (tails=0, heads=1)\n\t", y[0:10])
print("Overall frequency of Tails and Heads, accounting for uncertainty about theta itself\n\t", np.bincount(y)/10000)

plt.hist(y, density=True)
plt.xticks([.05,.95],["Tails","Heads"])
plt.show()
First 10 heads/tails values (tails=0, heads=1)
	 [1 1 0 0 0 0 1 0 0 0]
Overall frequency of Tails and Heads, accounting for uncertainty about theta itself
	 [0.7098 0.2902]
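
As a quick sanity check (a sketch, not part of the original output): marginally, P(heads) equals the mean of theta, and for a Beta(a,b) distribution the mean is a/(a+b). For Beta(2,5) that is 2/7 ≈ 0.286, close to the simulated frequency above.

In [ ]:
# Analytic check: P(heads) = E[theta] = 2/(2+5) for theta ~ Beta(2,5).
print("Analytic P(heads):", 2/(2+5))
print("Simulated P(heads):", np.mean(y))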

Rejection Sampling and the Weighted Bootstrap

Example adapted from https://wiseodd.github.io/techblog/2015/10/21/rejection-sampling/

In [2]:
sns.set()

def h(x):
    return st.norm.pdf(x, loc=30, scale=10) + st.norm.pdf(x, loc=80, scale=20)


def g(x):
    return st.norm.pdf(x, loc=50, scale=30)


x = np.arange(-50, 151)
M = max(h(x) / g(x))  # for rejection sampling

h is a mixture of two normal densities (unnormalized), and g, the proposal density, is a normal density with mean 50 and standard deviation 30. M is the largest value of h/g over the grid, so M*g envelops h.

In [17]:
plt.plot(x, h(x))
plt.show()
In [18]:
# Superimpose h and g on same plot
plt.plot(x,h(x))
plt.plot(x,g(x))
plt.show()
In [19]:
# Superimpose h and M*g on same plot - now M*g envelops h
plt.plot(x,h(x))
plt.plot(x,M*g(x))
plt.show()
In [5]:
def rejection_sampling(maxiter=10000,sampsize=1000):
    samples = []
    sampcount = 0  # counter for accepted samples
    maxcount = 0   # counter for proposal simulation
    # sampcount/maxcount at any point in the iteration is the acceptance rate

    while (sampcount < sampsize and maxcount < maxiter):
        z = np.random.normal(50, 30)
        u = np.random.uniform(0, 1)
        maxcount += 1

        if u <= h(z)/(M*g(z)):
            samples.append(z)
            sampcount += 1

    print('Rejection rate is',100*(1-sampcount/maxcount))
    if maxcount == maxiter: print('Maximum iterations achieved')
    return np.array(samples)

s = rejection_sampling(maxiter=10000,sampsize=1000)
sns.displot(s)
Rejection rate is 49.54591321897074
Out[5]:
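The empirical rejection rate lines up with what the envelope predicts. h integrates to 2 (it is the sum of two normalized pdfs) and each proposal is accepted with probability h(z)/(M*g(z)), so the overall acceptance rate is about 2/M. A quick check (a sketch, approximate because M was taken as max(h/g) on a finite grid):

In [ ]:
# Expected rejection rate for this envelope: 1 - (integral of h)/M = 1 - 2/M.
print("Expected rejection rate (%):", 100*(1 - 2/M))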
In [25]:
# weighted bootstrap computation involving h and g
import random

def weighted_bootstrap(iter=1000,size=100):
    w = []
    y = []

    for i in range(iter):
        z = np.random.normal(50, 30)
        y.append(z)
        wz = h(z)/g(z)
        w.append(wz)

    v = random.choices(y,weights=w,k=size) # do not need to renormalize w
    return np.array(v)

wb = weighted_bootstrap(iter=10000,size=1000)
sns.displot(wb)
Out[25]:
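
The same weights can also be used without the resampling step: a self-normalized importance-sampling estimate of an expectation under the normalized target h just weights the proposals by h/g. A minimal sketch (the variable names z and w here are new, reusing h and g from above):

In [ ]:
# Self-normalized importance sampling: estimate the mean under (normalized) h with no resampling.
# For this equal-weight mixture of N(30,10) and N(80,20), the true mean is (30+80)/2 = 55.
z = np.random.normal(50, 30, size=10000)
w = h(z) / g(z)
print("Importance-sampling estimate of the mean:", np.sum(w*z) / np.sum(w))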

Beetles

In [30]:
beetles_x = np.array([1.6907, 1.7242, 1.7552, 1.7842, 1.8113, 1.8369, 1.8610, 1.8839])
beetles_x_mean = beetles_x - np.mean(beetles_x)
beetles_n = np.array([59, 60, 62, 56, 63, 59, 62, 60])
beetles_y = np.array([6, 13, 18, 28, 52, 53, 61, 60])
beetles_N = np.array([8]*8)
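Before fitting anything, it helps to look at the raw kill proportions at each dose (a quick sketch): they rise from roughly 10% to 100%, the kind of monotone trend the logistic model below is meant to capture.

In [ ]:
# Observed proportion of beetles killed at each dose.
print(np.round(beetles_y / beetles_n, 3))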
In [31]:
from scipy.special import expit  # expit is the inverse-logit (logistic) function: expit(x) = 1/(1 + exp(-x))
expit(2)
Out[31]:
0.8807970779778823
In [32]:
with pm.Model() as beetle_model:
    # The intercept (the log-odds of a beetle dying at the mean dose, since the dose is centered)
    # is centered at zero, and wide-ranging (easily anywhere from 0 to 100%).
    # If we wanted, we could choose something like Normal(-3,2) for a baseline
    # death rate roughly between .007 and .25
    alpha_star = pm.Normal('alpha*', mu=0, sigma=100)
    # the effect on the log-odds of each unit of the dose is wide-ranging:
    # we're saying we've got little idea what the effect will be, and it could
    # be strongly negative.
    beta = pm.Normal('beta', mu=0, sigma=100)

    # given alpha, beta, and the dosage, the probability of death is deterministic:
    # it's the inverse logit of the intercept+slope*dosage
    # Because beetles_x has 8 entries, we end up with 8 p_i values
    p_i = pm.Deterministic('$P_i$', pm.math.invlogit(alpha_star + beta*beetles_x_mean))

    # finally, the number of beetles we see killed is Binomial(n=number of beetles, p=probability of death)
    deaths = pm.Binomial('obs_deaths', n=beetles_n, p=p_i, observed=beetles_y)

    trace = pm.sample(2000, tune=2000, target_accept=0.9)
/home/glickm/.local/lib/python3.9/site-packages/pymc3/sampling.py:465: FutureWarning: In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  warnings.warn(
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [beta, alpha*]
100.00% [8000/8000 00:02<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 2_000 tune and 2_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
In [33]:
az.plot_trace(trace, compact=False);
In [34]:
def trace_summary(trace, var_names=None):
    if var_names is None:
        var_names = trace.varnames

    quants = [0.025,0.25,0.5,0.75,0.975]
    colnames = ['mean', 'sd', *["{}%".format(x*100) for x in quants]]
    rownames = []

    series = []
    for cur_var in var_names:
        var_trace = trace[cur_var]
        if var_trace.ndim == 1:
            vals = [np.mean(var_trace, axis=0), np.std(var_trace, axis=0), *np.quantile(var_trace, quants, axis=0)]
            series.append(pd.Series(vals, colnames))
            rownames.append(cur_var)
        else:
            for i in range(var_trace.shape[1]):
                cur_col = var_trace[:,i]
                vals = [np.mean(cur_col, axis=0), np.std(cur_col, axis=0), *np.quantile(cur_col, quants, axis=0)]
                series.append(pd.Series(vals, colnames))
                rownames.append("{}[{}]".format(cur_var,i))

    return pd.DataFrame(series, index=rownames)

trace_summary(trace)
Out[34]:
mean sd 2.5% 25.0% 50.0% 75.0% 97.5%
alpha* 0.748408 0.142126 0.484625 0.647875 0.746483 0.844844 1.035891
beta 34.598944 2.952582 29.160611 32.533871 34.531622 36.484915 40.660437
$P_i$[0] 0.059033 0.016087 0.032537 0.047387 0.057374 0.068892 0.094452
$P_i$[1] 0.163540 0.028339 0.111428 0.143686 0.162045 0.182056 0.220862
$P_i$[2] 0.361014 0.034651 0.294708 0.337156 0.360277 0.385165 0.429006
$P_i$[3] 0.605225 0.032339 0.543134 0.582635 0.605705 0.627732 0.668252
$P_i$[4] 0.795550 0.026790 0.744184 0.777235 0.796137 0.814087 0.846153
$P_i$[5] 0.903212 0.018882 0.864207 0.891043 0.904234 0.916307 0.937264
$P_i$[6] 0.954857 0.011829 0.929360 0.947438 0.955950 0.963315 0.974979
$P_i$[7] 0.978645 0.007007 0.962990 0.974396 0.979538 0.983683 0.989899

We can also plot the density each chain explored. Any major deviations between chains are signs of difficulty converging.

In [35]:
for x in trace.varnames:
    az.plot_forest(trace, var_names=[x], combined=True)
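
To see the chains separately rather than pooled, az.plot_forest also accepts combined=False (a minimal sketch for one variable):

In [ ]:
# One interval per chain; large disagreement between the chains would suggest convergence trouble.
az.plot_forest(trace, var_names=['beta'], combined=False);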

In addition to the above summaries of the distribution, the ArviZ summary reports statistics intended to summarize the quality of the samples. The most common of these is r_hat, which measures whether the different chains seem to be exploring the same distribution or are stuck in different regions. An r_hat above 1.3 is a strong sign the sample isn't good yet; values close to 1 are ideal.
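
The diagnostic can also be computed on its own (a sketch; like az.summary below, it emits the same from_pymc3 warning when given the raw trace):

In [ ]:
# r_hat for every variable in the trace; values near 1 indicate the chains agree.
az.rhat(trace)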

In [16]:
az.summary(trace)
/home/glickm/.local/lib/python3.9/site-packages/arviz/data/io_pymc3.py:87: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  warnings.warn(
Out[16]:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_mean ess_sd ess_bulk ess_tail r_hat
alpha* 0.750 0.139 0.499 1.020 0.003 0.002 2840.0 2696.0 2867.0 2296.0 1.0
beta 34.626 2.963 28.628 39.807 0.055 0.039 2905.0 2841.0 2945.0 2388.0 1.0
$P_i$[0] 0.059 0.016 0.030 0.088 0.000 0.000 3520.0 3433.0 3458.0 2536.0 1.0
$P_i$[1] 0.163 0.028 0.111 0.214 0.000 0.000 3695.0 3675.0 3666.0 2439.0 1.0
$P_i$[2] 0.361 0.034 0.295 0.423 0.001 0.000 3701.0 3678.0 3714.0 3039.0 1.0
$P_i$[3] 0.606 0.031 0.550 0.666 0.001 0.000 3081.0 3035.0 3101.0 2536.0 1.0
$P_i$[4] 0.796 0.026 0.747 0.844 0.001 0.000 2712.0 2693.0 2677.0 2360.0 1.0
$P_i$[5] 0.903 0.019 0.868 0.938 0.000 0.000 2687.0 2680.0 2632.0 2055.0 1.0
$P_i$[6] 0.955 0.012 0.933 0.977 0.000 0.000 2722.0 2720.0 2653.0 2112.0 1.0
$P_i$[7] 0.979 0.007 0.966 0.991 0.000 0.000 2760.0 2759.0 2680.0 2122.0 1.0

Sleep Study

In [17]:
import pandas as pd
sleepstudy = pd.read_csv("sleepstudy.csv")
In [18]:
sleepstudy
Out[18]:
Reaction Days Subject
0 249.5600 0 308
1 258.7047 1 308
2 250.8006 2 308
3 321.4398 3 308
4 356.8519 4 308
... ... ... ...
175 329.6076 5 372
176 334.4818 6 372
177 343.2199 7 372
178 369.1417 8 372
179 364.1236 9 372

180 rows × 3 columns

In [19]:
# adding a column that numbers the subjects from 0 to n-1
raw_ids = np.unique(sleepstudy['Subject'])
raw2newid = {x:np.where(raw_ids == x)[0][0] for x in raw_ids}

sleepstudy['SeqSubject'] = sleepstudy['Subject'].map(raw2newid)
sleepstudy
Out[19]:
Reaction Days Subject SeqSubject
0 249.5600 0 308 0
1 258.7047 1 308 0
2 250.8006 2 308 0
3 321.4398 3 308 0
4 356.8519 4 308 0
... ... ... ... ...
175 329.6076 5 372 17
176 334.4818 6 372 17
177 343.2199 7 372 17
178 369.1417 8 372 17
179 364.1236 9 372 17

180 rows × 4 columns
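
pandas can build an equivalent sequential index in one line (a sketch, not what was run here); pd.factorize numbers subjects by order of appearance rather than sorted ID, which works just as well for indexing model parameters as long as it is used consistently.

In [ ]:
# codes runs 0..17 (one value per subject); uniques holds the original subject IDs.
codes, uniques = pd.factorize(sleepstudy['Subject'])
print(codes[:5], uniques[:5])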

In [20]:
with pm.Model() as sleep_model:

    # In this model, we're going to say the alphas (individuals' intercepts; their starting reaction time)
    # and betas (individuals' slopes; how much worse they get with lack of sleep) are normally distributed.
    # We'll specify that we're certain about the mean of those distributions [more on that later], but admit
    # we're uncertain about how much spread there is. Tau_alpha and Tau_beta will be the respective
    # precisions (1/variance), which is how pm.Normal's `tau` argument expresses the spread.
    #
    # Of course, a precision must be positive (a negative precision isn't mathematically possible), so we draw
    # them from a Gamma, which cannot ever output negative numbers. Here, we use alpha and beta values that
    # spread the distribution: "the spread could be anything!". If we had more intuition (e.g. "the subjects'
    # intercepts can't have SD above 3,000") we would plot the corresponding Gamma(a,b) and tune its parameters
    # so that there was little mass on precisions below 1/3,000^2, then use those values below.
    tau_alpha = pm.Gamma('tau_alpha', alpha=.001, beta=.001)
    tau_beta = pm.Gamma('tau_beta', alpha=.001, beta=.001)

    # Across the population of people, we suppose that
    # the slopes are normally distributed, as are the intercepts,
    # and the two are drawn independently
    #
    # (Here, we hard-code assumed means, but we don't have to.
    # In general, these should be set from our pre-data intuition,
    # rather than from plots/exploration of the data)
    alpha = pm.Normal('alpha', mu=300, tau=tau_alpha, shape=len(raw_ids))
    beta = pm.Normal('beta', mu=10, tau=tau_beta, shape=len(raw_ids))

    # Remember: there's only one alpha/beta per person, but
    # we have lots of observations per person. The below
    # builds a vector with one entry per observation, recording
    # the alpha/beta we want to use with that observation.
    #
    # That is, the length is 180, but it only has 18 unique values,
    # matching the 18 subjects' personal slopes or intercepts
    intercepts = alpha[sleepstudy['SeqSubject']]
    slopes = beta[sleepstudy['SeqSubject']]

    # now we have the true/predicted response time for each observation (each row of original data)
    # (Here we use pm.Deterministic to signal that this is something we'll care about)
    mu_i = pm.Deterministic('mu_i', intercepts + slopes*sleepstudy['Days'])

    # The _observed_ values are noisy versions of the hidden true values, however! 
    # Specifically, we model them as normal around the true value with a single unknown precision tau_obs
    # (one explanation: we're saying the measurement equipment adds normally-distributed noise whose precision
    # is tau_obs, so the noise doesn't vary from observation to observation or person to person: there's just
    # one universal noise level)
    tau_obs = pm.Gamma('tau_obs', 0.001, 0.001)
    obs = pm.Normal('observed', mu=mu_i, tau=tau_obs, observed=sleepstudy['Reaction'])

    trace = pm.sample(2000, tune=2000, target_accept=0.9)
/home/glickm/.local/lib/python3.9/site-packages/pymc3/sampling.py:465: FutureWarning: In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  warnings.warn(
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [tau_obs, beta, alpha, tau_beta, tau_alpha]
100.00% [8000/8000 00:16<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 2_000 tune and 2_000 draw iterations (4_000 + 4_000 draws total) took 17 seconds.
In [21]:
# this command can take a few minutes to finish... or never :-/
#az.plot_trace(trace);
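A lighter-weight alternative (a sketch) is to trace only the scalar hyperparameters rather than all of the per-subject alphas and betas and the 180 mu_i values:

In [ ]:
# Trace plots for the three precision parameters only; this finishes quickly.
az.plot_trace(trace, var_names=['tau_alpha', 'tau_beta', 'tau_obs'], compact=False);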
In [22]:
trace_summary(trace, var_names=['tau_alpha', 'tau_beta', 'alpha', 'beta', 'tau_obs'])
Out[22]:
mean sd 2.5% 25.0% 50.0% 75.0% 97.5%
tau_alpha 0.000346 0.000126 0.000148 0.000256 0.000330 0.000415 0.000632
tau_beta 0.033659 0.015637 0.013079 0.023094 0.030559 0.040748 0.072732
alpha[0] 258.152009 14.085221 229.587865 248.875997 258.031798 267.548487 285.905580
alpha[1] 203.987112 13.798313 177.167592 194.562082 204.073271 213.510437 231.477768
alpha[2] 206.126645 13.890339 179.265453 196.916117 205.844053 215.249876 233.399891
alpha[3] 283.997308 13.540079 257.583370 274.934328 283.951485 293.016335 310.396503
alpha[4] 282.410892 13.518552 255.264925 273.224767 282.552675 291.504971 308.997292
alpha[5] 265.806415 13.531943 239.467840 256.855144 265.607351 274.862315 292.612168
alpha[6] 275.757596 13.320304 249.045762 267.021728 275.787887 284.376585 302.710662
alpha[7] 245.740413 13.714924 219.631293 236.229034 245.464725 255.177507 272.645132
alpha[8] 253.870482 14.255955 225.796467 244.251182 253.993258 263.769349 281.444838
alpha[9] 298.553345 13.755800 271.078080 289.243294 298.823119 307.431320 325.569036
alpha[10] 223.610381 13.990920 196.093509 214.008280 223.561530 233.307163 251.138863
alpha[11] 238.871549 13.761936 211.447031 229.581532 239.062261 247.913414 265.707026
alpha[12] 260.368873 13.413522 233.823679 251.571895 260.526191 269.315076 285.961425
alpha[13] 280.742010 13.829812 253.175439 271.743084 280.698178 289.996805 307.893626
alpha[14] 258.800724 13.492456 231.671886 249.901913 258.941095 267.936373 285.521491
alpha[15] 223.528867 14.165204 196.547061 213.777409 223.291463 233.232906 251.393429
alpha[16] 255.784792 13.460106 229.445291 246.648575 255.697083 265.101533 282.197640
alpha[17] 270.456634 13.558380 243.777977 261.344061 270.401545 279.587780 297.676718
beta[0] 18.880788 2.636258 13.763675 17.116352 18.826027 20.619523 24.186513
beta[1] 2.971997 2.571664 -2.007634 1.211468 2.980907 4.680537 8.058643
beta[2] 6.005203 2.555142 0.961511 4.326752 6.030397 7.745092 10.972043
beta[3] 4.308504 2.497274 -0.653887 2.619826 4.312722 5.983469 9.216847
beta[4] 6.111156 2.504019 1.258975 4.388857 6.113937 7.779817 10.971868
beta[5] 9.295519 2.463106 4.511339 7.626141 9.290686 10.936594 14.084005
beta[6] 9.101971 2.422311 4.292805 7.465481 9.102780 10.698237 13.893795
beta[7] 11.260015 2.500966 6.308763 9.555129 11.297954 12.939190 16.112974
beta[8] -0.646928 2.663938 -5.951725 -2.427896 -0.650508 1.184857 4.604859
beta[9] 17.165573 2.559193 12.221724 15.413945 17.140421 18.823903 22.378692
beta[10] 11.949868 2.529520 6.957955 10.273520 11.929817 13.635788 16.907556
beta[11] 16.909605 2.535119 11.934900 15.232328 16.884626 18.602441 21.954967
beta[12] 6.782651 2.460363 1.774711 5.181803 6.800000 8.406616 11.625047
beta[13] 12.685451 2.551037 7.689587 10.984800 12.725454 14.378128 17.749365
beta[14] 10.666371 2.469020 5.793505 9.070583 10.656603 12.269673 15.562084
beta[15] 15.553713 2.563791 10.643113 13.771380 15.569299 17.274128 20.549376
beta[16] 8.910774 2.469050 4.100091 7.209558 8.957572 10.572388 13.671468
beta[17] 10.708832 2.500048 5.846379 8.982458 10.748232 12.420740 15.558593
tau_obs 0.001517 0.000178 0.001191 0.001391 0.001511 0.001637 0.001882
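Since the tau parameters are precisions, converting the posterior draws of tau_obs to a standard deviation makes the noise level easier to read (a sketch):

In [ ]:
# sigma = 1/sqrt(tau); with tau_obs around 0.0015 this is roughly 25 ms of observation noise.
sigma_obs = 1 / np.sqrt(trace['tau_obs'])
print("Posterior mean of sigma_obs (ms):", sigma_obs.mean())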
In [23]:
az.summary(trace, var_names=['tau_alpha', 'tau_beta', 'alpha', 'beta', 'tau_obs'])
/home/glickm/.local/lib/python3.9/site-packages/arviz/data/io_pymc3.py:87: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  warnings.warn(
Out[23]:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_mean ess_sd ess_bulk ess_tail r_hat
tau_alpha 0.000 0.000 0.000 0.001 0.000 0.000 5872.0 5400.0 5408.0 2903.0 1.0
tau_beta 0.034 0.016 0.011 0.061 0.000 0.000 2799.0 2560.0 3067.0 3254.0 1.0
alpha[0] 258.152 14.087 229.981 283.320 0.208 0.147 4598.0 4577.0 4615.0 2671.0 1.0
alpha[1] 203.987 13.800 178.345 230.169 0.237 0.167 3402.0 3402.0 3405.0 3394.0 1.0
alpha[2] 206.127 13.892 180.632 232.780 0.215 0.152 4179.0 4179.0 4146.0 3045.0 1.0
alpha[3] 283.997 13.542 259.688 310.220 0.199 0.141 4626.0 4626.0 4629.0 3100.0 1.0
alpha[4] 282.411 13.520 256.634 307.660 0.214 0.152 3982.0 3982.0 3961.0 3202.0 1.0
alpha[5] 265.806 13.534 242.135 293.297 0.208 0.147 4222.0 4216.0 4222.0 3289.0 1.0
alpha[6] 275.758 13.322 249.853 300.722 0.212 0.150 3949.0 3949.0 3924.0 3194.0 1.0
alpha[7] 245.740 13.717 220.770 271.740 0.215 0.152 4079.0 4079.0 4080.0 3263.0 1.0
alpha[8] 253.870 14.258 227.625 281.013 0.243 0.172 3438.0 3438.0 3437.0 3205.0 1.0
alpha[9] 298.553 13.758 272.172 324.239 0.243 0.172 3214.0 3211.0 3210.0 2818.0 1.0
alpha[10] 223.610 13.993 198.706 251.420 0.212 0.150 4359.0 4359.0 4353.0 2827.0 1.0
alpha[11] 238.872 13.764 211.660 263.609 0.228 0.161 3643.0 3642.0 3649.0 2780.0 1.0
alpha[12] 260.369 13.415 235.299 285.691 0.216 0.153 3868.0 3859.0 3872.0 3168.0 1.0
alpha[13] 280.742 13.832 254.178 305.953 0.223 0.157 3858.0 3858.0 3863.0 3249.0 1.0
alpha[14] 258.801 13.494 233.524 284.470 0.210 0.149 4111.0 4111.0 4126.0 2873.0 1.0
alpha[15] 223.529 14.167 196.548 249.253 0.228 0.162 3845.0 3831.0 3850.0 3142.0 1.0
alpha[16] 255.785 13.462 230.441 280.705 0.217 0.154 3846.0 3844.0 3846.0 2647.0 1.0
alpha[17] 270.457 13.560 243.626 294.911 0.212 0.150 4085.0 4085.0 4076.0 3205.0 1.0
beta[0] 18.881 2.637 13.858 23.717 0.042 0.030 3984.0 3973.0 3982.0 3262.0 1.0
beta[1] 2.972 2.572 -1.896 7.753 0.044 0.033 3481.0 3019.0 3469.0 3303.0 1.0
beta[2] 6.005 2.555 1.354 10.922 0.038 0.029 4423.0 4020.0 4425.0 3095.0 1.0
beta[3] 4.309 2.498 -0.366 9.042 0.038 0.030 4425.0 3469.0 4426.0 2886.0 1.0
beta[4] 6.111 2.504 1.495 10.792 0.041 0.030 3721.0 3538.0 3718.0 3168.0 1.0
beta[5] 9.296 2.463 4.766 14.010 0.038 0.028 4142.0 3784.0 4135.0 2781.0 1.0
beta[6] 9.102 2.423 4.659 13.710 0.037 0.028 4262.0 3842.0 4255.0 3053.0 1.0
beta[7] 11.260 2.501 6.322 15.730 0.039 0.028 4029.0 3908.0 4023.0 3078.0 1.0
beta[8] -0.647 2.664 -5.742 4.330 0.045 0.038 3478.0 2423.0 3472.0 3124.0 1.0
beta[9] 17.166 2.560 12.338 22.019 0.044 0.031 3436.0 3436.0 3432.0 3106.0 1.0
beta[10] 11.950 2.530 7.478 17.014 0.038 0.027 4385.0 4295.0 4385.0 2637.0 1.0
beta[11] 16.910 2.535 12.220 21.790 0.043 0.031 3493.0 3446.0 3502.0 2766.0 1.0
beta[12] 6.783 2.461 2.333 11.682 0.040 0.028 3878.0 3846.0 3875.0 3166.0 1.0
beta[13] 12.685 2.551 7.697 17.322 0.041 0.029 3883.0 3753.0 3885.0 3076.0 1.0
beta[14] 10.666 2.469 5.744 15.045 0.038 0.027 4124.0 4114.0 4121.0 3187.0 1.0
beta[15] 15.554 2.564 10.636 20.174 0.042 0.030 3726.0 3709.0 3714.0 3312.0 1.0
beta[16] 8.911 2.469 4.372 13.518 0.040 0.028 3833.0 3767.0 3865.0 2918.0 1.0
beta[17] 10.709 2.500 6.022 15.286 0.038 0.027 4260.0 4192.0 4266.0 3048.0 1.0
tau_obs 0.002 0.000 0.001 0.002 0.000 0.000 4092.0 4028.0 4113.0 3219.0 1.0
In [24]:
import statsmodels.formula.api as sm
import seaborn as sns
from matplotlib import gridspec


ymin,ymax = np.min(sleepstudy["Reaction"]),np.max(sleepstudy["Reaction"])
plt.figure(figsize=(11,8.5))
gs  = gridspec.GridSpec(3, 6)
gs.update(wspace=0.5, hspace=0.5)
for i, subj in enumerate(np.unique(sleepstudy['Subject'])):
    ss_extract = sleepstudy.loc[sleepstudy['Subject']==subj]
    ss_extract_ols = sm.ols(formula="Reaction~Days",data=ss_extract).fit()
    #new subplot
    subplt = plt.subplot(gs[i])
    #plot without confidence intervals
    sns.regplot(x='Days', y='Reaction', ci=None, data=ss_extract).set_title('Subject '+str(subj))
    if i not in [0,6,12]:
        plt.ylabel("")
    subplt.set_ylim(ymin,ymax)

_ = plt.figlegend(['Estimated from each subject alone'],loc = 'lower center', ncol=6)
_ = plt.show()
In [25]:
plt.figure(figsize=(11,8.5))
for i, subj in enumerate(np.unique(sleepstudy['Subject'])):
    ss_extract = sleepstudy.loc[sleepstudy['Subject']==subj]
    #new subplot
    subplt = plt.subplot(gs[i])
    #plot without confidence intervals
    sns.regplot(x='Days', y='Reaction', ci=None, data=ss_extract).set_title('Subject '+str(subj))
    sns.regplot(x='Days', y='Reaction', ci=None, scatter=False, data=sleepstudy)
    if i not in [0,6,12]:
        plt.ylabel("")
    subplt.set_ylim(ymin,ymax)

_ = plt.figlegend(['Estimated from each subject alone','Pooling all subjects'],loc = 'lower center', ncol=6)
_ = plt.show()
In [26]:
plt.figure(figsize=(11,8.5))
subj_arr = np.unique(sleepstudy['Subject'])
for i, subj in enumerate(subj_arr):
    ss_extract = sleepstudy.loc[sleepstudy['Subject']==subj]
    #new subplot
    subplt = plt.subplot(gs[i])

    #plot without confidence intervals
    sns.regplot(x='Days', y='Reaction', ci=None, data=ss_extract).set_title('Subject '+str(subj))
    sns.regplot(x='Days', y='Reaction', ci=None, scatter=False, data=sleepstudy)

    subjects_avg_intercept = np.mean(trace['alpha'][:,i])
    subjects_avg_slope = np.mean(trace['beta'][:,i])
    hmodel_fit = [subjects_avg_intercept + subjects_avg_slope*x for x in range(-1,11)]
    sns.lineplot(x=range(-1,11),y=hmodel_fit)
    if i not in [0,6,12]:
        plt.ylabel("")
    subplt.set_ylim(ymin,ymax)

_ = plt.figlegend(['Estimated from each subject alone','Pooling all subjects','Hierarchical (partial pooling)'],loc = 'lower center', ncol=6)
_ = plt.show()
In [27]:
model_predictions = trace['mu_i'].mean(axis=0)
obs_reactions = sleepstudy['Reaction']

plt.figure(figsize=(11,8.5))
plt.scatter(sleepstudy['Reaction'], model_predictions)
plt.plot(plt.xlim(), plt.ylim(), c='black')
plt.xlabel("Observed Reaction Time (ms)")
plt.ylabel("Predicted Reaction Time [Mean of Posterior] (ms)")
plt.title("Observed and Fitted Reaction Times from . Bayesian Hierarchical Model")
plt.show()
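A quick numeric companion to that scatterplot (a sketch): the in-sample root-mean-squared difference between observed and fitted reaction times.

In [ ]:
# In-sample RMSE of the posterior-mean predictions; optimistic, since the same data informed the fit.
rmse = np.sqrt(np.mean((obs_reactions - model_predictions)**2))
print("In-sample RMSE (ms):", rmse)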
In [ ]: