predictor
pypythia.predictor.DifficultyPredictor
Class structure for the trained difficulty predictor.
This class provides methods for predicting the difficulty of an MSA and for plotting its Shapley values.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_file | Path | Path to the trained difficulty predictor model. Defaults to the latest model shipped with PyPythia. Note that this model file must be in the LightGBM .txt format. | DEFAULT_MODEL_FILE |
features | list[str] | Names of the features the predictor was trained with. Defaults to None. In this case, the features are inferred from the model file. | None |
Attributes:
Name | Type | Description |
---|---|---|
predictor | | Loaded trained predictor. |
features | | Names of the features the predictor was trained with. |
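The constructor arguments described above can be sketched as follows. This is a minimal illustration, not the library's own code: the helper names `predictor_kwargs` and `load_predictor` are hypothetical, and the import inside `load_predictor` assumes PyPythia is installed. Only arguments the caller sets are passed on, so the documented defaults (`DEFAULT_MODEL_FILE` and `None`) stay in effect otherwise.

```python
from pathlib import Path


def predictor_kwargs(model_file=None, features=None):
    """Collect only the constructor arguments the caller wants to override.

    Hypothetical helper: anything left as None is omitted, so the
    documented defaults of DifficultyPredictor apply.
    """
    kwargs = {}
    if model_file is not None:
        # The documentation expects a Path to a LightGBM .txt model file.
        kwargs["model_file"] = Path(model_file)
    if features is not None:
        kwargs["features"] = list(features)
    return kwargs


def load_predictor(model_file=None, features=None):
    """Build a DifficultyPredictor (assumes PyPythia is installed)."""
    # Imported lazily so this sketch can be read and tested without PyPythia.
    from pypythia.predictor import DifficultyPredictor

    return DifficultyPredictor(**predictor_kwargs(model_file, features))
```

Calling `load_predictor()` with no arguments would then load the model shipped with PyPythia and infer the feature names from the model file.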
Source code in pypythia/predictor.py
plot_shapley_values(query)
Plot the Shapley values for the first MSA in the given query DataFrame. Any further rows in the query are ignored.
Please read our notes on SHAP values in the documentation to understand the plot.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
query | DataFrame | DataFrame containing the features for which to plot the Shapley values. | required |
Returns:
Type | Description |
---|---|
Figure | A matplotlib Figure object containing the waterfall plot of the Shapley values for the first MSA in the query. |
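Because the method returns a matplotlib Figure rather than displaying it, a typical next step is to write the plot to disk. The sketch below shows one way to do that. `save_shapley_plot` is a hypothetical helper, and `StubFigure` and `StubPlotPredictor` are stand-ins that mimic only the calls used here, so the example runs without a trained model or matplotlib installed.

```python
from pathlib import Path


class StubFigure:
    """Hypothetical stand-in for matplotlib's Figure; records the save target."""

    def __init__(self):
        self.saved_to = None

    def savefig(self, fname, **kwargs):
        self.saved_to = Path(fname)


class StubPlotPredictor:
    """Hypothetical stand-in exposing plot_shapley_values(query)."""

    def __init__(self):
        self.figure = StubFigure()

    def plot_shapley_values(self, query):
        return self.figure


def save_shapley_plot(predictor, query, out_file):
    """Plot the Shapley values for the first MSA in `query` and save the figure."""
    fig = predictor.plot_shapley_values(query)
    out_file = Path(out_file)
    # Figure.savefig is standard matplotlib; bbox_inches="tight" trims margins.
    fig.savefig(out_file, bbox_inches="tight")
    return out_file


stub = StubPlotPredictor()
saved = save_shapley_plot(stub, query=None, out_file="shapley.png")
```

With a real `DifficultyPredictor`, `query` would be the feature DataFrame. Since only the first row is plotted, a single-row selection such as `query.iloc[[i]]` can be passed to plot a different MSA.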
Source code in pypythia/predictor.py
predict(query)
Predict the difficulty for a set of MSAs defined by rows in the given query dataframe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
query | DataFrame | DataFrame containing the features for which to predict the difficulty. Each row in the DataFrame corresponds to a single MSA and the columns correspond to the features. | required |
Returns:
Type | Description |
---|---|
NDArray[float64] | A NumPy array of float64 values holding the predicted difficulty of each MSA in the query. The difficulties lie in the range [0, 1], where higher values indicate higher difficulty. |
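Since `predict` returns one score per row of the query, the result can be paired with row labels and sorted. The following sketch assumes nothing beyond the documented `predict(query)` call. `rank_by_difficulty` is a hypothetical helper, and `StubPredictor` returns fixed made-up scores in place of a real `DifficultyPredictor`, so the example runs without a trained model.

```python
class StubPredictor:
    """Hypothetical stand-in exposing the documented predict(query) call."""

    def predict(self, query):
        # Fixed illustrative scores, one per row; a real predictor derives
        # these from the feature columns of the query.
        return [0.12, 0.87, 0.45][: len(query)]


def rank_by_difficulty(predictor, query, names):
    """Return (name, difficulty) pairs sorted from hardest to easiest."""
    scores = [float(s) for s in predictor.predict(query)]
    if len(scores) != len(names):
        raise ValueError("expected one name per row of the query")
    return sorted(zip(names, scores), key=lambda pair: pair[1], reverse=True)


# Placeholder rows; with PyPythia this would be a DataFrame whose columns
# match the features the predictor was trained with.
query = [[0.0], [0.0], [0.0]]
ranking = rank_by_difficulty(StubPredictor(), query, ["msa_a", "msa_b", "msa_c"])
```

Sorting in descending order puts the MSAs with the highest predicted difficulty first, which follows from the documented meaning of the [0, 1] range.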
Source code in pypythia/predictor.py