Download this notebook from SiddharthaPradhan/stanbkt

Multi BKT Example Using Simulated Grouped Data

This walkthrough mirrors the simple example (02_simple_example.ipynb) but uses grouped data and MultiBKT.

In this notebook, we will:

  1. Simulate grouped BKT data with group- and KC-specific parameters.

  2. Instantiate and fit a MultiBKT model.

  3. Generate Numba-accelerated point-estimate predictions (predict and predict_smoothed).

  4. Draw posterior predictions and plot posterior correctness, both aggregated across groups and per group.

Simulate Grouped BKT Data

We simulate students from multiple groups. Each group can have different BKT parameters, and in this example, parameters also vary across KCs.
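To make the generative process concrete, here is a minimal pure-Python sketch of how one student's responses on one KC could be simulated under BKT with forgetting. It is not the implementation behind sim_grouped_BKT; the function name simulate_bkt_sequence and its signature are hypothetical and only illustrate the roles of prior, learn, forget, guess, and slip.

```python
import random


def simulate_bkt_sequence(prior, learn, forget, guess, slip, n_steps, rng):
    """Simulate one student's 0/1 responses on one KC (BKT with forgetting).

    Hypothetical helper for illustration; not part of stanbkt.
    """
    known = rng.random() < prior  # initial latent mastery state
    responses = []
    for _ in range(n_steps):
        # Emission: a student who knows the KC answers correctly unless they
        # slip; a student who does not know it is correct only by guessing.
        p_correct = 1.0 - slip if known else guess
        responses.append(int(rng.random() < p_correct))
        # Transition: forget from the known state, learn from the unknown one.
        if known:
            known = rng.random() >= forget
        else:
            known = rng.random() < learn
    return responses


rng = random.Random(12345)
# Values for group 0 / KC 0 from the bkt_params dictionary in this notebook.
seq = simulate_bkt_sequence(0.20, 0.03, 0.01, 0.25, 0.10, n_steps=20, rng=rng)
print(seq)
```

A grouped simulation would repeat this per student, looking up the parameter set that matches the student's group and the KC of each problem.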

[1]:
from stanbkt.utils import sim_grouped_BKT
[2]:
N_GROUPS = 3
N_KCS = 2

# rows=groups, cols=KCs
bkt_params = {
    "prior": [[0.20, 0.10], [0.35, 0.25], [0.50, 0.40]],
    "learn": [[0.03, 0.06], [0.05, 0.08], [0.08, 0.10]],
    "forget": [[0.01, 0.01], [0.02, 0.02], [0.02, 0.03]],
    "guess": [[0.25, 0.20], [0.20, 0.15], [0.15, 0.10]],
    "slip": [[0.10, 0.08], [0.08, 0.07], [0.06, 0.05]],
}
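Because each entry is a groups-by-KCs grid, a quick shape and range check can catch transposed or malformed inputs before simulating. The sketch below is an optional sanity check written for this walkthrough; check_param_grid is a hypothetical helper, not a stanbkt function.

```python
N_GROUPS, N_KCS = 3, 2

# Same grids as in the cell above: rows are groups, columns are KCs.
bkt_params = {
    "prior": [[0.20, 0.10], [0.35, 0.25], [0.50, 0.40]],
    "learn": [[0.03, 0.06], [0.05, 0.08], [0.08, 0.10]],
    "forget": [[0.01, 0.01], [0.02, 0.02], [0.02, 0.03]],
    "guess": [[0.25, 0.20], [0.20, 0.15], [0.15, 0.10]],
    "slip": [[0.10, 0.08], [0.08, 0.07], [0.06, 0.05]],
}


def check_param_grid(params, n_groups, n_kcs):
    """Raise ValueError unless every grid is n_groups x n_kcs with values in [0, 1]."""
    for name, grid in params.items():
        if len(grid) != n_groups:
            raise ValueError(f"{name}: expected {n_groups} rows, got {len(grid)}")
        for row in grid:
            if len(row) != n_kcs:
                raise ValueError(f"{name}: expected {n_kcs} columns, got {len(row)}")
            if any(not 0.0 <= p <= 1.0 for p in row):
                raise ValueError(f"{name}: probabilities must lie in [0, 1]")


check_param_grid(bkt_params, N_GROUPS, N_KCS)

# Indexing is [group][kc]: the learn rate for group index 2, KC index 1.
print(bkt_params["learn"][2][1])  # 0.1
```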
[3]:
data_df = sim_grouped_BKT(
    n_students=60,
    n_problems=80,
    n_kcs=N_KCS,
    n_groups=N_GROUPS,
    frac=0.85,
    rng_seed=12345,
    **bkt_params,
)
[4]:
data_df.head(10)
[4]:
   student_id  problem_id  correct            timestamp  kc_id  group_id
0       stu_0      prob_0        0  2024-01-01 00:00:00   kc_1   group_0
1       stu_0      prob_1        1  2024-01-01 00:01:00   kc_0   group_0
2       stu_0      prob_2        0  2024-01-01 00:02:00   kc_1   group_0
3       stu_0      prob_3        1  2024-01-01 00:03:00   kc_0   group_0
4       stu_0      prob_4        1  2024-01-01 00:04:00   kc_0   group_0
5       stu_0      prob_6        0  2024-01-01 00:06:00   kc_1   group_0
6       stu_0      prob_7        1  2024-01-01 00:07:00   kc_1   group_0
7       stu_0      prob_8        1  2024-01-01 00:08:00   kc_1   group_0
8       stu_0      prob_9        1  2024-01-01 00:09:00   kc_0   group_0
9       stu_0     prob_10        0  2024-01-01 00:10:00   kc_1   group_0
[5]:
data_df[["group_id", "kc_id"]].value_counts().sort_index()
[5]:
group_id  kc_id
group_0   kc_0     658
          kc_1     700
group_1   kc_0     663
          kc_1     705
group_2   kc_0     656
          kc_1     698
Name: count, dtype: int64

MultiBKT expects long-format data including the standard columns (student_id, problem_id, correct, timestamp, kc_id) plus group_id.

[6]:
required_cols = [
    "student_id",
    "problem_id",
    "correct",
    "timestamp",
    "kc_id",
    "group_id",
]
data_df[required_cols].head(5)
[6]:
   student_id  problem_id  correct            timestamp  kc_id  group_id
0       stu_0      prob_0        0  2024-01-01 00:00:00   kc_1   group_0
1       stu_0      prob_1        1  2024-01-01 00:01:00   kc_0   group_0
2       stu_0      prob_2        0  2024-01-01 00:02:00   kc_1   group_0
3       stu_0      prob_3        1  2024-01-01 00:03:00   kc_0   group_0
4       stu_0      prob_4        1  2024-01-01 00:04:00   kc_0   group_0
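When working with your own data rather than simulated data, it can help to confirm the expected columns are present before fitting. The following is a small pandas sketch of such a check; missing_columns is a hypothetical helper written for this example, and the toy frame deliberately omits group_id.

```python
import pandas as pd

REQUIRED_COLS = [
    "student_id",
    "problem_id",
    "correct",
    "timestamp",
    "kc_id",
    "group_id",
]


def missing_columns(df, required=REQUIRED_COLS):
    """Return the required column names that are absent from df."""
    return [col for col in required if col not in df.columns]


# Toy long-format frame without the group_id column.
toy = pd.DataFrame(
    {
        "student_id": ["stu_0", "stu_0"],
        "problem_id": ["prob_0", "prob_1"],
        "correct": [0, 1],
        "timestamp": pd.to_datetime(["2024-01-01 00:00:00", "2024-01-01 00:01:00"]),
        "kc_id": ["kc_1", "kc_0"],
    }
)

print(missing_columns(toy))  # ['group_id']
```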

Instantiate and fit MultiBKT

As in the simple example, Stan code is compiled lazily on the first fit(...) call and cached for reuse.

[7]:
from stanbkt.fits import FitMethod, MCMCFitOptions
from stanbkt.models import MultiBKT, MultiPriors
from stanbkt.utils import VerbosityLevel
[8]:
model = MultiBKT(
    fit_method=FitMethod.MCMC,
    verbose=VerbosityLevel.WARN,
)

fit_opts = MCMCFitOptions(
    seed=1234,
    iter_warmup=500,
    iter_sampling=500,
)

priors = MultiPriors(use_defaults=True)  # using default priors for all groups and KCs

Unlike the previous example (02_simple_example.ipynb), we will pass the entire DataFrame containing both KCs.

[9]:
model.fit(data_df, stan_fit_options=fit_opts, priors=priors)
17:00:36 - cmdstanpy - INFO - CmdStan start processing
17:00:47 - cmdstanpy - INFO - CmdStan done processing.
17:00:47 - cmdstanpy - INFO - CmdStan start processing
17:00:55 - cmdstanpy - INFO - CmdStan done processing.

[9]:
MultiBKT(fit_method=<FitMethod.MCMC: 'mcmc'>, verbose=<VerbosityLevel.WARN: 1>, is_fitted=True)

Generate a summary of the parameters’ posterior distributions.

[10]:
summary = model.summary()
summary
[10]:
kc_id  parameter                      Mean      MCSE    StdDev       MAD         2.5%          50%        97.5%  ESS_bulk  ESS_tail     R_hat
kc_1   lp__                    -921.694000  0.118435  3.010790  2.884400  -928.645000  -921.360000  -916.903000   660.542   843.480  1.002110
       logit_pi_know_group[1]    -3.428550  0.085990  2.119080  1.248890    -9.653540    -2.811810    -1.208540  1028.270   871.355  1.008320
       logit_pi_know_group[2]    -0.953987  0.015008  0.626301  0.585673    -2.271920    -0.910853     0.165178  2127.500  1143.810  1.001320
       logit_pi_know_group[3]     0.063329  0.009998  0.541030  0.517822    -1.026090     0.077880     1.113430  2983.860  1291.860  0.999119
       logit_learn_group[1]      -2.610510  0.004827  0.250038  0.254466    -3.126990    -2.605160    -2.150770  2962.300  1118.880  1.001850
...    ...                             ...       ...       ...       ...          ...          ...          ...       ...       ...       ...
kc_0   guess[2]                   0.206578  0.000706  0.035417  0.036829     0.141762     0.205580     0.279148  2559.340  1405.850  1.003210
       guess[3]                   0.148244  0.000900  0.039277  0.039619     0.079686     0.146121     0.233238  1955.870  1367.580  1.001490
       slip[1]                    0.096202  0.000450  0.019514  0.018577     0.060655     0.095115     0.138666  1905.790  1533.050  1.002140
       slip[2]                    0.048679  0.000269  0.012363  0.012165     0.027220     0.047605     0.075275  2159.430  1354.460  1.000900
       slip[3]                    0.039134  0.000215  0.010497  0.010506     0.020468     0.038814     0.061320  2326.940  1208.900  1.001970

62 rows × 10 columns
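The R_hat, ESS_bulk, and ESS_tail columns are convergence diagnostics, and it is worth scanning them before interpreting the estimates. The sketch below shows one way to flag suspicious rows in a summary-like DataFrame. The thresholds (R_hat above 1.01, effective sample size below 400) are common rules of thumb assumed here, not values prescribed by stanbkt, and flag_poor_mixing is a hypothetical helper. The toy frame copies three kc_1 rows from the output above.

```python
import pandas as pd


def flag_poor_mixing(summary_df, rhat_max=1.01, ess_min=400.0):
    """Return rows whose R_hat is too high or whose ESS is too low."""
    bad = (
        (summary_df["R_hat"] > rhat_max)
        | (summary_df["ESS_bulk"] < ess_min)
        | (summary_df["ESS_tail"] < ess_min)
    )
    return summary_df[bad]


# Three kc_1 rows copied from the summary shown above.
rows = pd.DataFrame(
    {
        "ESS_bulk": [660.542, 1028.270, 2983.860],
        "ESS_tail": [843.480, 871.355, 1291.860],
        "R_hat": [1.002110, 1.008320, 0.999119],
    },
    index=["lp__", "logit_pi_know_group[1]", "logit_pi_know_group[3]"],
)

print(flag_poor_mixing(rows).empty)                 # True
print(list(flag_poor_mixing(rows, ess_min=700.0).index))  # ['lp__']
```

With the default thresholds none of these rows is flagged; tightening ess_min to 700 flags lp__, whose bulk ESS is 660.542.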

[11]:
summary.loc["kc_0"]  # summary for kc_0
[11]:
parameter                      Mean      MCSE    StdDev       MAD         2.5%          50%        97.5%  ESS_bulk  ESS_tail     R_hat
lp__                    -837.621000  0.108199  2.998410  2.808790  -844.811000  -837.255000  -832.962000   800.308  1268.890  1.002780
logit_pi_know_group[1]    -1.186370  0.016435  0.665201  0.602940    -2.679060    -1.115120    -0.057710  1958.970  1067.480  1.003470
logit_pi_know_group[2]     0.612810  0.012302  0.562335  0.570284    -0.454810     0.600089     1.745280  2128.260  1473.040  0.999152
logit_pi_know_group[3]     1.053900  0.014190  0.630915  0.560652    -0.109400     1.015180     2.427780  2303.210  1122.930  1.001530
logit_learn_group[1]      -3.303280  0.008623  0.383281  0.364401    -4.125830    -3.265930    -2.636570  2149.770  1240.310  0.999253
logit_learn_group[2]      -2.791580  0.008285  0.382229  0.385587    -3.613330    -2.774590    -2.097270  2205.470  1401.090  1.004720
logit_learn_group[3]      -2.138990  0.007243  0.320278  0.303473    -2.830860    -2.117300    -1.555850  2049.870  1156.790  1.002120
logit_forget_group[1]     -5.545130  0.058078  1.518500  1.055960   -10.029100    -5.232840    -3.670290  1321.660   575.351  1.004160
logit_forget_group[2]     -3.417570  0.006945  0.331542  0.317573    -4.118930    -3.394710    -2.817380  2410.420  1507.380  0.999805
logit_forget_group[3]     -3.798870  0.009233  0.395320  0.374668    -4.680470    -3.767470    -3.112510  2092.970  1201.110  1.000330
logit_guess_group[1]       0.073996  0.005242  0.249760  0.240632    -0.408548     0.075708     0.576238  2271.090  1558.560  1.006470
logit_guess_group[2]      -0.358434  0.005954  0.298411  0.308658    -0.927045    -0.359170     0.234254  2559.260  1405.850  1.003330
logit_guess_group[3]      -0.893599  0.008853  0.386720  0.381542    -1.662910    -0.884521    -0.134305  1955.850  1367.580  1.001680
logit_slip_group[1]       -1.454010  0.005880  0.253594  0.244844    -1.980090    -1.448510    -0.957734  1905.790  1533.050  1.001990
logit_slip_group[2]       -2.259540  0.006370  0.287611  0.281694    -2.854670    -2.251620    -1.730280  2159.370  1354.460  1.000970
logit_slip_group[3]       -2.503820  0.006651  0.304977  0.290575    -3.153980    -2.475010    -1.967660  2326.990  1208.900  1.001500
pi_know[1]                 0.252838  0.002373  0.109792  0.112821     0.064220     0.246917     0.485576  1958.970  1067.480  1.002510
pi_know[2]                 0.638896  0.002605  0.121326  0.129587     0.388218     0.645676     0.851357  2128.270  1473.040  0.999152
pi_know[3]                 0.724938  0.002259  0.112757  0.107481     0.472677     0.734031     0.918922  2303.210  1122.930  1.001410
learn[1]                   0.037740  0.000283  0.013327  0.012925     0.015894     0.036758     0.066822  2149.770  1240.310  0.999471
learn[2]                   0.061239  0.000448  0.021269  0.020952     0.026254     0.058713     0.109362  2205.480  1401.090  1.003750
learn[3]                   0.109103  0.000666  0.030232  0.029169     0.055679     0.107427     0.174243  2049.870  1156.790  1.001100
forget[1]                  0.007198  0.000141  0.006578  0.005373     0.000044     0.005310     0.024836  1321.650   575.351  1.001810
forget[2]                  0.033308  0.000209  0.010347  0.010119     0.016002     0.032461     0.056392  2410.380  1507.380  0.999805
forget[3]                  0.023470  0.000182  0.008690  0.008185     0.009189     0.022588     0.042594  2093.100  1201.110  1.000330
guess[1]                   0.259102  0.000645  0.030671  0.029930     0.199631     0.259459     0.320101  2271.060  1558.560  1.006450
guess[2]                   0.206578  0.000706  0.035417  0.036829     0.141762     0.205580     0.279148  2559.340  1405.850  1.003210
guess[3]                   0.148244  0.000900  0.039277  0.039619     0.079686     0.146121     0.233238  1955.870  1367.580  1.001490
slip[1]                    0.096202  0.000450  0.019514  0.018577     0.060655     0.095115     0.138666  1905.790  1533.050  1.002140
slip[2]                    0.048679  0.000269  0.012363  0.012165     0.027220     0.047605     0.075275  2159.430  1354.460  1.000900
slip[3]                    0.039134  0.000215  0.010497  0.010506     0.020468     0.038814     0.061320  2326.940  1208.900  1.001970
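The summary lists each parameter twice: once with a logit_ prefix and once on the probability scale. The names suggest the two are linked by the inverse-logit (sigmoid) function; that link is an inference from the naming rather than something this notebook states. The sketch below checks it numerically for learn[1] on kc_0.

```python
import math


def inv_logit(x):
    """Map a value on the logit scale back to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


# Values copied from the kc_0 summary above for logit_learn_group[1].
median_logit_learn = -3.265930  # 50% column
mean_logit_learn = -3.303280    # Mean column

print(round(inv_logit(median_logit_learn), 6))
print(round(inv_logit(mean_logit_learn), 6))
```

The transformed median comes out at roughly 0.0368, in line with the 50% value reported for learn[1] (0.036758), because quantiles are preserved under a monotone transform. The transformed mean is roughly 0.0355, below the reported mean of learn[1] (0.037740): the mean of a transformed quantity is generally not the transform of the mean, which is one reason to read probabilities from the probability-scale rows rather than converting the logit-scale means by hand.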

Numba-based point-estimate predictions

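predict(...) and predict_smoothed(...) run Numba-compiled routines internally. Both plug a point estimate of each parameter (the posterior mean in the calls below) into the BKT computations instead of working with the full set of posterior draws, which keeps inference fast.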

[12]:
from stanbkt.utils.data_utils import ColumnNames

col_mapping = {
    ColumnNames.STUDENT_ID: "student_id",
    ColumnNames.PROBLEM_ID: "problem_id",
    ColumnNames.CORRECTNESS: "correct",
    ColumnNames.ORDER: "timestamp",
    ColumnNames.KC_ID: "kc_id",
    ColumnNames.GROUP: "group_id",
}
[13]:
predictions = model.predict(
    data_df,
    column_mapping=col_mapping,
    point_estimate="mean",
    parallel=True,
    fast_math=True,
)

predictions.head(20)
[13]:
    kc_id  student_id  problem_id     pKnow  pCorrectness  correct
 0   kc_1       stu_0      prob_0  0.072479      0.216648        0
 1   kc_1       stu_0      prob_2  0.079254      0.221594        0
 2   kc_1       stu_0      prob_6  0.080166      0.222260        0
 3   kc_1       stu_0      prob_7  0.080290      0.222350        1
 4   kc_1       stu_0      prob_8  0.368654      0.432900        1
 5   kc_1       stu_0     prob_10  0.774099      0.728935        0
 6   kc_1       stu_0     prob_12  0.350429      0.419592        0
 7   kc_1       stu_0     prob_17  0.129425      0.258227        0
 8   kc_1       stu_0     prob_19  0.087297      0.227467        0
 9   kc_1       stu_0     prob_20  0.081264      0.223061        0
10   kc_1       stu_0     prob_22  0.080439      0.222459        0
11   kc_1       stu_0     prob_24  0.080327      0.222377        0
12   kc_1       stu_0     prob_25  0.080311      0.222366        0
13   kc_1       stu_0     prob_33  0.080309      0.222365        0
14   kc_1       stu_0     prob_37  0.080309      0.222364        1
15   kc_1       stu_0     prob_40  0.368708      0.432938        0
16   kc_1       stu_0     prob_44  0.133983      0.261555        0
17   kc_1       stu_0     prob_50  0.087980      0.227965        0
18   kc_1       stu_0     prob_51  0.081358      0.223130        0
19   kc_1       stu_0     prob_59  0.080451      0.222468        0
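For intuition about the numbers above, here is a minimal pure-Python sketch of the standard BKT forward (filtering) recursion with forgetting. It is not stanbkt's Numba routine; bkt_filter is a hypothetical helper, and it assumes pKnow is reported before each response is observed. That assumption at least matches the pattern in the table, where pKnow rises on the row after a correct response.

```python
def bkt_filter(responses, prior, learn, forget, guess, slip):
    """Return (p_know, p_correct) traces, each computed before seeing the response.

    Assumes guess and slip lie strictly between 0 and 1 so that the Bayes
    updates below never divide by zero.
    """
    p_know = prior
    p_know_trace, p_correct_trace = [], []
    for y in responses:
        # Predicted probability of a correct answer at this opportunity.
        p_correct = p_know * (1.0 - slip) + (1.0 - p_know) * guess
        p_know_trace.append(p_know)
        p_correct_trace.append(p_correct)
        # Bayes update of mastery given the observed response.
        if y == 1:
            posterior = p_know * (1.0 - slip) / p_correct
        else:
            posterior = p_know * slip / (1.0 - p_correct)
        # Transition to the next opportunity: may forget or learn.
        p_know = posterior * (1.0 - forget) + (1.0 - posterior) * learn
    return p_know_trace, p_correct_trace


# Simulation values for group 0 / KC 0, used here only as example inputs.
p_know, p_correct = bkt_filter([1, 0, 0, 1], 0.20, 0.03, 0.01, 0.25, 0.10)
for pk, pc in zip(p_know, p_correct):
    print(f"{pk:.6f}  {pc:.6f}")
```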
[14]:
smoothed_predictions = model.predict_smoothed(
    data_df,
    column_mapping=col_mapping,
    point_estimate="mean",
    parallel=True,
    fast_math=True,
)

smoothed_predictions.head(20)
[14]:
    kc_id  student_id  problem_id     pKnow  pCorrectness  correct
 0   kc_1       stu_0      prob_0  0.000107      0.163805        0
 1   kc_1       stu_0      prob_2  0.000427      0.164039        0
 2   kc_1       stu_0      prob_6  0.002743      0.165730        0
 3   kc_1       stu_0      prob_7  0.019798      0.178182        1
 4   kc_1       stu_0      prob_8  0.019798      0.178183        1
 5   kc_1       stu_0     prob_10  0.002748      0.165733        0
 6   kc_1       stu_0     prob_12  0.000434      0.164043        0
 7   kc_1       stu_0     prob_17  0.000120      0.163814        0
 8   kc_1       stu_0     prob_19  0.000077      0.163783        0
 9   kc_1       stu_0     prob_20  0.000071      0.163779        0
10   kc_1       stu_0     prob_22  0.000071      0.163779        0
11   kc_1       stu_0     prob_24  0.000078      0.163784        0
12   kc_1       stu_0     prob_25  0.000124      0.163818        0
13   kc_1       stu_0     prob_33  0.000469      0.164070        0
14   kc_1       stu_0     prob_37  0.003011      0.165926        1
15   kc_1       stu_0     prob_40  0.000469      0.164070        0
16   kc_1       stu_0     prob_44  0.000124      0.163818        0
17   kc_1       stu_0     prob_50  0.000078      0.163784        0
18   kc_1       stu_0     prob_51  0.000071      0.163779        0
19   kc_1       stu_0     prob_59  0.000070      0.163778        0
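Smoothing differs from filtering in that each mastery estimate uses the whole response sequence, including later responses. The sketch below implements the textbook forward-backward algorithm for the two-state BKT hidden Markov model in pure Python. It is an illustration under stated assumptions, not stanbkt's routine: bkt_smooth is a hypothetical helper, and the state at step t is taken to be the one that generated response t, which may not match how predict_smoothed indexes its output.

```python
def bkt_smooth(responses, prior, learn, forget, guess, slip):
    """Smoothed P(known at step t | all responses) via forward-backward."""
    n = len(responses)
    # trans[i][j] = P(next state j | current state i); 0 = unknown, 1 = known.
    trans = [[1.0 - learn, learn], [forget, 1.0 - forget]]

    def emit(y):
        # Likelihood of response y under [unknown, known].
        return [guess, 1.0 - slip] if y == 1 else [1.0 - guess, slip]

    # Forward pass: normalized filtered state probabilities.
    alphas = []
    pred = [1.0 - prior, prior]
    for y in responses:
        e = emit(y)
        a = [pred[0] * e[0], pred[1] * e[1]]
        z = a[0] + a[1]
        a = [a[0] / z, a[1] / z]
        alphas.append(a)
        pred = [
            a[0] * trans[0][0] + a[1] * trans[1][0],
            a[0] * trans[0][1] + a[1] * trans[1][1],
        ]

    # Backward pass: likelihood of future responses given the current state.
    betas = [[1.0, 1.0] for _ in range(n)]
    for t in range(n - 2, -1, -1):
        e = emit(responses[t + 1])
        b = [
            sum(trans[i][j] * e[j] * betas[t + 1][j] for j in range(2))
            for i in range(2)
        ]
        z = b[0] + b[1]
        betas[t] = [b[0] / z, b[1] / z]  # rescale for numerical stability

    # Combine both passes and normalize per step.
    smoothed = []
    for a, b in zip(alphas, betas):
        g0, g1 = a[0] * b[0], a[1] * b[1]
        smoothed.append(g1 / (g0 + g1))
    return smoothed


print(bkt_smooth([1, 0, 0, 1], 0.20, 0.03, 0.01, 0.25, 0.10))
```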

Predictions and Visualizations

Because MultiBKT estimates a separate set of parameters for each group, posterior correctness can be visualized in two ways: aggregated across groups or split by group.

By default, plot_posterior_correctness(...) aggregates across groups and produces a single panel for the selected KC. If you want to inspect how posterior correctness differs by group, pass grouped=True. In that mode, the function creates one subplot per group.
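The difference between the two views can be illustrated numerically. The NumPy sketch below summarizes synthetic draws of predicted correctness, first over all observations and then per group. Everything in it is a stand-in: the draws array is random data, not stanbkt output, and its (draws, observations) layout and the interval helper are assumptions made for this example, since the structure returned by predict_posterior_draws is not described here.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws, n_obs = 200, 12

# Synthetic stand-in for posterior draws of P(correct); NOT stanbkt output.
draws = rng.beta(2.0, 3.0, size=(n_draws, n_obs))
group = np.array(["group_0"] * 4 + ["group_1"] * 4 + ["group_2"] * 4)


def interval(d):
    """Mean and 95% interval of the average predicted correctness across draws."""
    per_draw = d.mean(axis=1)  # average over observations within each draw
    lo, hi = np.percentile(per_draw, [2.5, 97.5])
    return per_draw.mean(), lo, hi


# Aggregated view: one summary over every observation.
print("all groups:", interval(draws))

# Grouped view: one summary per group, using only that group's observations.
for g in np.unique(group):
    print(g, interval(draws[:, group == g]))
```

Each per-group summary uses fewer observations than the aggregated one, so its interval tends to be wider.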

[15]:
pred_post_draws = model.predict_posterior_draws(data_df, column_mapping=col_mapping)
17:00:56 - cmdstanpy - INFO - Chain [1] start processing
17:00:56 - cmdstanpy - INFO - Chain [2] start processing
17:00:56 - cmdstanpy - INFO - Chain [3] start processing
17:00:56 - cmdstanpy - INFO - Chain [4] start processing
17:00:57 - cmdstanpy - INFO - Chain [1] done processing
17:00:57 - cmdstanpy - INFO - Chain [2] done processing
17:00:57 - cmdstanpy - INFO - Chain [4] done processing
17:00:57 - cmdstanpy - INFO - Chain [3] done processing
17:00:57 - cmdstanpy - INFO - Chain [1] start processing
17:00:57 - cmdstanpy - INFO - Chain [2] start processing
17:00:57 - cmdstanpy - INFO - Chain [3] start processing
17:00:57 - cmdstanpy - INFO - Chain [4] start processing
17:00:57 - cmdstanpy - INFO - Chain [1] done processing
17:00:57 - cmdstanpy - INFO - Chain [4] done processing
17:00:57 - cmdstanpy - INFO - Chain [3] done processing
17:00:57 - cmdstanpy - INFO - Chain [2] done processing
17:00:58 - cmdstanpy - WARNING - Sample doesn't contain draws from warmup iterations, rerun sampler with "save_warmup=True".
17:00:59 - cmdstanpy - WARNING - Sample doesn't contain draws from warmup iterations, rerun sampler with "save_warmup=True".
[ ]:
from stanbkt.plot import plot_posterior_correctness

# correctness predictions (with predictive intervals)
axes = plot_posterior_correctness(
    posterior_pred_kc=pred_post_draws["kc_1"],
    data=data_df,
    column_mapping=col_mapping,
    kc="kc_1",
    type="preds",
    trajectory=True,
    frac=0.5,
)
../_images/examples_03_multi_BKT_example_23_0.png
[ ]:
axes_individual_groups = plot_posterior_correctness(
    posterior_pred_kc=pred_post_draws["kc_1"],
    data=data_df,
    column_mapping=col_mapping,
    kc="kc_1",
    type="preds",
    trajectory=True,
    grouped=True,
    frac=0.5,
)
../_images/examples_03_multi_BKT_example_24_0.png