LightGBM: [Feature] Faster prediction by converting tree model into code

Created on 27 Apr 2017 · 13 comments · Source: microsoft/LightGBM

  1. Convert tree models into chains of if ... else ... statements.
  2. Multi-threading across the trees, so prediction for a single instance is still very fast.
Labels: feature request, help wanted
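To make idea 1 concrete, here is a minimal, self-contained Python sketch (a hypothetical toy tree, not a real LightGBM dump, though the node keys mirror `dump_model()`'s JSON): a nested-dict tree is "compiled" into flat if/else source, and the compiled function agrees with a generic recursive walker. Emitting Python instead of C++/Java just makes it testable in place.

```python
# Hypothetical toy tree in the style of dump_model()'s tree_structure
# (numeric splits only; keys chosen to mirror the real JSON fields).
tree = {
    "split_feature": 0, "threshold": 0.5,
    "left_child": {"leaf_value": -1.0},
    "right_child": {
        "split_feature": 2, "threshold": 3.0,
        "left_child": {"leaf_value": 0.25},
        "right_child": {"leaf_value": 1.5},
    },
}

def walk(node, arr):
    # Generic interpreter: chases child pointers at every node.
    if "split_feature" not in node:
        return node["leaf_value"]
    child = "left_child" if arr[node["split_feature"]] <= node["threshold"] else "right_child"
    return walk(node[child], arr)

def compile_tree(node, indent="    "):
    # Emit straight-line if/else source for the same tree.
    if "split_feature" not in node:
        return indent + "return " + repr(node["leaf_value"]) + "\n"
    cond = "arr[%d] <= %r" % (node["split_feature"], node["threshold"])
    return (indent + "if " + cond + ":\n"
            + compile_tree(node["left_child"], indent + "    ")
            + indent + "else:\n"
            + compile_tree(node["right_child"], indent + "    "))

src = "def predict_tree(arr):\n" + compile_tree(tree)
namespace = {}
exec(src, namespace)          # "compile" the generated source
predict_tree = namespace["predict_tree"]

# The compiled function and the interpreter agree on every input.
for arr in ([0.0, 0, 0], [1.0, 0, 2.0], [1.0, 0, 9.0]):
    assert predict_tree(arr) == walk(tree, arr)
```

The compiled form avoids dict lookups and pointer chasing per node, which is where the speed-up for single-instance prediction comes from.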

Most helpful comment

I wrote some Python code like the snippet below; it converts the dumped trees into C++/Java code:

model_json = gbm.dump_model()

def parseOneTree(root, index, array_type='double', return_type='double'):
    def ifElse(node):
        if 'leaf_index' in node:
            # Leaf node: emit the raw leaf value.
            return 'return ' + str(node['leaf_value']) + ';'
        else:
            # Internal node: emit the split condition, then recurse into both children.
            condition = 'arr[' + str(node['split_feature']) + ']'
            if node['decision_type'] == 'no_greater':
                condition += ' <= ' + str(node['threshold'])
            else:
                # Categorical split.
                condition += ' == ' + str(node['threshold'])
            left = ifElse(node['left_child'])
            right = ifElse(node['right_child'])
            return 'if ( ' + condition + ' ) { ' + left + ' } else { ' + right + ' }'
    return return_type + ' predictTree' + str(index) + '(' + array_type + '[] arr) { ' + ifElse(root) + ' }'

def parseAllTrees(trees, array_type='double', return_type='double'):
    # One function per tree, plus a predict() that sums all tree outputs.
    return '\n\n'.join([parseOneTree(tree['tree_structure'], idx, array_type, return_type) for idx, tree in enumerate(trees)]) \
        + '\n\n' + return_type + ' predict(' + array_type + '[] arr) { ' \
        + 'return ' + ' + '.join(['predictTree' + str(i) + '(arr)' for i in range(len(trees))]) + ';' \
        + '}'

with open('if.else', 'w+') as f:
    f.write(parseAllTrees(model_json["tree_info"]))

I made a simple test with Java; the resulting code is as follows, and its predictions match LightGBM's built-in predict function:

double predictTree0(double[] arr) { if ( arr[0] <= 0 ) { return -0.9431; } else { return -0.9431; } }

double predictTree1(double[] arr) { if ( arr[13] <= -30.5 ) { if ( arr[21] <= 45.5 ) { if ( arr[22] <= 16.5 ) { if ( arr[9] <= 22.5 ) { if ( arr[6] <= 2.5 ) { if ( arr[23] <= 43.5 ) { if ( arr[15] <= -30.5 ) { return 0.638946; } else { if ( arr[13] <= -33.5 ) { if ( arr[22] <= -4.5 ) { if ( arr[21] <= 28.5 ) { return -0.0531971; } else { return -0.877039; } } else { return 0.397859; } } else { if ( arr[19] <= 6.5 ) { return -0.0255723; } else { return 1.23515; } } } } else { return 1.2543; } } else { if ( arr[3] <= 39.5 ) { if ( arr[7] <= -11.5 ) { if ( arr[17] <= -29.5 ) { return 1.14063; } else { if ( arr[9] <= 7.5 ) { return -0.132012; } else { return 0.877155; } } } else { if ( arr[17] <= -35.5 ) { return -1.04375; } else { return -0.169512; } } } else { return -0.961178; } } } else { return -0.233391; } } else { if ( arr[20] <= -23.5 ) { if ( arr[25] <= 30.5 ) { return -0.281633; } else { return 0.554298; } } else { if ( arr[0] <= 5.5 ) { if ( arr[3] <= 4.5 ) { if ( arr[26] <= 23.5 ) { return 0.575822; } else { return 1.45787; } } else { return 0.186044; } } else { if ( arr[23] <= -2.5 ) { if ( arr[11] <= 33.5 ) { if ( arr[25] <= 14.5 ) { return 0.597155; } else { return -0.420236; } } else { return 1.21965; } } else { return -0.242002; } } } } } else { if ( arr[5] <= 5.5 ) { return 0.128822; } else { return 1.07283; } } } else { if ( arr[25] <= 30.5 ) { if ( arr[19] <= -44.5 ) { return 0.238409; } else { if ( arr[4] <= 4.5 ) { if ( arr[24] <= -49.5 ) { return -0.958614; } else { if ( arr[24] <= -34.5 ) { return 0.289861; } else { return 0.0101188; } } } else { return -0.0913412; } } } else { return -0.122343; } } }

double predictTree2(double[] arr) { if ( arr[14] <= 48.5 ) { if ( arr[13] <= -30.5 ) { if ( arr[21] <= 45.5 ) { if ( arr[3] <= 46.5 ) { return 0.0707513; } else { return -0.566098; } } else { return 0.617633; } } else { if ( arr[16] <= -25.5 ) { if ( arr[21] <= 19.5 ) { if ( arr[9] <= 6.5 ) { if ( arr[12] <= -45.5 ) { return 0.943692; } else { if ( arr[4] <= -8.5 ) { return 0.290743; } else { if ( arr[19] <= -23.5 ) { if ( arr[26] <= -8.5 ) { return -1.06871; } else { if ( arr[0] <= -0.5 ) { return -0.759557; } else { return 0.314001; } } } else { if ( arr[7] <= -41.5 ) { return 1.07732; } else { if ( arr[18] <= 30.5 ) { return 0.110816; } else { return -0.502498; } } } } } } else { if ( arr[5] <= 3.5 ) { return 0.105686; } else { if ( arr[11] <= 33.5 ) { return -0.548442; } else { if ( arr[4] <= 0.5 ) { return 1.04067; } else { return -0.412117; } } } } } else { if ( arr[16] <= -39.5 ) { if ( arr[8] <= 14.5 ) { if ( arr[8] <= 5.5 ) { if ( arr[4] <= 27.5 ) { return -0.224228; } else { return 0.725575; } } else { return 1.11628; } } else { return -0.430015; } } else { if ( arr[7] <= 30.5 ) { if ( arr[18] <= -2.5 ) { if ( arr[12] <= 27.5 ) { return -0.243358; } else { return 0.635237; } } else { return 0.534673; } } else { return 0.944614; } } } } else { if ( arr[25] <= -30.5 ) { if ( arr[4] <= 26.5 ) { if ( arr[13] <= -24.5 ) { return 0.692067; } else { if ( arr[26] <= -25.5 ) { return -0.127491; } else { return 0.190598; } } } else { if ( arr[26] <= 28.5 ) { return -0.100314; } else { return -0.796962; } } } else { return -0.0923691; } } } } else { if ( arr[26] <= -7.5 ) { return 0.931673; } else { return -0.00291374; } } }

double predictTree3(double[] arr) { if ( arr[1] <= 48.5 ) { if ( arr[14] <= 48.5 ) { if ( arr[10] <= 29.5 ) { if ( arr[8] <= -46.5 ) { if ( arr[10] <= -24.5 ) { return 0.74522; } else { return 0.13412; } } else { if ( arr[4] <= -0.5 ) { if ( arr[5] <= -36.5 ) { if ( arr[23] <= -31.5 ) { return 0.787587; } else { if ( arr[8] <= -5.5 ) { if ( arr[5] <= -47.5 ) { return -0.226148; } else { if ( arr[15] <= -9.5 ) { return 0.148057; } else { return 1.12421; } } } else { if ( arr[2] <= -31.5 ) { return -0.656214; } else { if ( arr[0] <= -28.5 ) { return -0.520197; } else { return 0.383778; } } } } } else { if ( arr[23] <= -47.5 ) { return -0.449624; } else { if ( arr[16] <= 7.5 ) { return 0.106182; } else { return -0.0755566; } } } } else { if ( arr[14] <= 30.5 ) { return -0.102649; } else { if ( arr[5] <= -44.5 ) { return -0.522098; } else { return 0.171904; } } } } } else { if ( arr[2] <= -21.5 ) { if ( arr[20] <= 20.5 ) { return -0.419594; } else { if ( arr[21] <= -25.5 ) { return -0.734804; } else { if ( arr[8] <= -25.5 ) { return -0.503601; } else { return 0.584065; } } } } else { if ( arr[2] <= -3.5 ) { if ( arr[7] <= -38.5 ) { return -0.627834; } else { if ( arr[21] <= 5.5 ) { if ( arr[8] <= 1.5 ) { return -0.380093; } else { return 0.370773; } } else { if ( arr[26] <= -9.5 ) { return 1.07691; } else { return 0.278066; } } } } else { if ( arr[11] <= 27.5 ) { return -0.188583; } else { return 0.157865; } } } } } else { if ( arr[5] <= 8.5 ) { if ( arr[0] <= -10.5 ) { return 0.00207745; } else { return 1.29174; } } else { return -0.132605; } } } else { if ( arr[6] <= 18.5 ) { return 0.791693; } else { return -0.182431; } } }

double predict(double[] arr) { return predictTree0(arr) + predictTree1(arr) + predictTree2(arr) + predictTree3(arr);}
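One caveat worth adding (my note, not from the thread): the generated predict() returns the raw score, i.e. the plain sum of the per-tree outputs. For a binary-classification objective, LightGBM's built-in predict additionally passes that sum through the sigmoid before reporting a probability, so a caller comparing against probabilities would need something like this sketch:

```python
import math

def raw_to_probability(raw_score):
    # For a binary objective, predicted probability = sigmoid(raw score);
    # the generated predict() above yields only the raw score.
    return 1.0 / (1.0 + math.exp(-raw_score))

# A raw score of 0 corresponds to probability 0.5.
```

For regression objectives the raw sum is already the prediction, so no transform is needed there.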

All 13 comments

what are the possible values of "decision_type"?

@wxchan only 2 types are considered now.

@guolinke I don't understand. Are there any other decision_type values besides no_greater?

I chose Java mainly because it can be used in Hadoop or Spark.

@wxchan
for a categorical feature, it may be if (x == 1) (decision_type=1).

I would like to have both C++ and Java versions (with only a few differences).

Some other suggestions:

  1. I think it is better to use the model file directly, not the model JSON.
  2. Multi-threading support (put these prediction functions into a function array and use OpenMP).
  3. Support predict_type (raw_score, normal, leaf_index).
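Suggestion 2 (a function array plus OpenMP) can be sketched in Python with a thread pool; the per-tree functions here are hypothetical stand-ins for the generated predictTreeN functions, and a C++ version would presumably use an OpenMP parallel-for with a sum reduction instead.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-ins for the generated predictTree0..predictTreeN.
def predict_tree0(arr):
    return -0.9431

def predict_tree1(arr):
    return 0.5 if arr[0] <= 1.0 else -0.5

TREES = [predict_tree0, predict_tree1]  # the "function array"

def predict_parallel(arr, pool):
    # Trees are independent, so their outputs can be computed in parallel
    # and summed -- the analogue of an OpenMP for-loop with a sum reduction.
    return sum(pool.map(lambda tree: tree(arr), TREES))

with ThreadPoolExecutor(max_workers=2) as pool:
    result = predict_parallel([0.0], pool)

assert result == predict_tree0([0.0]) + predict_tree1([0.0])
```

In practice the OpenMP version pays off only when the model has enough trees to amortize the per-thread overhead.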

@guolinke so decision_type=0 means no_greater, and decision_type=1 means the categorical ==?

Do you mean writing this parser in C++/Java, or generating the code in C++/Java (like what I do now)?

@wxchan
yes.

Using Python to generate the code is ok.
I mean generating both C++ and Java code.

@guolinke what's your future plan for this feature, after #469 ?

@wxchan, I think #469 is good enough. Do you have any improvements in mind?

@guolinke Not really, yet. I just wonder how this feature performs on bigger datasets like #396 and #435 . Maybe add support for it in the python-package if it really helps.

@wxchan
you mean using Python to generate the code?

Actually, I think it would be better if we could automatically replace the predict code in gbdt.cpp.

I mean exposing it through the c_api.

Actually, recompiling with the long hard-coded predict function takes a lot of time on my machine; I don't know whether that happens on other machines too.
