if ... else ... codes. what are the possible values of "decision_type"?
@wxchan Only 2 types are considered for now.
I wrote some Python code like the below; it converts the dumped trees into C++/Java code:
model_json = gbm.dump_model()
def parseOneTree(root, index, array_type='double', return_type='double'):
    """Convert one dumped tree into a C++/Java function returning its leaf value.

    Parameters
    ----------
    root : dict
        The 'tree_structure' dict of one dumped tree. Internal nodes carry
        'split_feature', 'decision_type', 'threshold', 'left_child' and
        'right_child'; leaves carry 'leaf_index' and 'leaf_value'.
    index : int
        Tree index; the generated function is named ``predictTree<index>``.
    array_type : str
        C++/Java element type of the feature array parameter.
    return_type : str
        C++/Java return type of the generated function.

    Returns
    -------
    str
        Source text of ``return_type predictTree<index>(array_type[] arr)``.
    """
    def ifElse(node):
        # Leaves are recognized by the presence of 'leaf_index'.
        if 'leaf_index' in node:
            return 'return ' + str(node['leaf_value']) + ';'
        condition = 'arr[' + str(node['split_feature']) + ']'
        # 'no_greater' is a numerical split (x <= threshold); any other
        # decision_type (a categorical split) compares for equality —
        # only these two types exist at the moment (see discussion).
        if node['decision_type'] == 'no_greater':
            condition += ' <= ' + str(node['threshold'])
        else:
            condition += ' == ' + str(node['threshold'])
        left = ifElse(node['left_child'])
        right = ifElse(node['right_child'])
        return 'if ( ' + condition + ' ) { ' + left + ' } else { ' + right + ' }'

    return return_type + ' predictTree' + str(index) + '(' + array_type + '[] arr) { ' + ifElse(root) + ' }'
def parseAllTrees(trees, array_type='double', return_type='double'):
    """Convert every dumped tree to code, plus a summing ``predict`` entry point.

    Parameters
    ----------
    trees : list of dict
        The 'tree_info' list from the dumped model; each item holds a
        'tree_structure' dict understood by :func:`parseOneTree`.
    array_type : str
        C++/Java element type of the feature array parameter.
    return_type : str
        C++/Java return type of the generated functions.

    Returns
    -------
    str
        One ``predictTree<i>`` function per tree, then a ``predict`` function
        whose result is the sum of all per-tree predictions (the raw score).
    """
    # One standalone function per tree, separated by blank lines.
    tree_functions = '\n\n'.join(
        parseOneTree(tree['tree_structure'], idx, array_type, return_type)
        for idx, tree in enumerate(trees))
    # The aggregate: boosting predicts the sum over every tree's output.
    calls = ' + '.join('predictTree' + str(i) + '(arr)' for i in range(len(trees)))
    predict_function = (return_type + ' predict(' + array_type + '[] arr) { '
                        + 'return ' + calls + ';' + '}')
    return tree_functions + '\n\n' + predict_function
# Emit the generated source to a file; 'w+' truncates any existing content.
with open('if.else', 'w+') as output_file:
    output_file.write(parseAllTrees(model_json["tree_info"]))
I made a simple test with Java; the resulting code is as follows, and the prediction result is the same as that of the built-in predict function:
double predictTree0(double[] arr) { if ( arr[0] <= 0 ) { return -0.9431; } else { return -0.9431; } }
double predictTree1(double[] arr) { if ( arr[13] <= -30.5 ) { if ( arr[21] <= 45.5 ) { if ( arr[22] <= 16.5 ) { if ( arr[9] <= 22.5 ) { if ( arr[6] <= 2.5 ) { if ( arr[23] <= 43.5 ) { if ( arr[15] <= -30.5 ) { return 0.638946; } else { if ( arr[13] <= -33.5 ) { if ( arr[22] <= -4.5 ) { if ( arr[21] <= 28.5 ) { return -0.0531971; } else { return -0.877039; } } else { return 0.397859; } } else { if ( arr[19] <= 6.5 ) { return -0.0255723; } else { return 1.23515; } } } } else { return 1.2543; } } else { if ( arr[3] <= 39.5 ) { if ( arr[7] <= -11.5 ) { if ( arr[17] <= -29.5 ) { return 1.14063; } else { if ( arr[9] <= 7.5 ) { return -0.132012; } else { return 0.877155; } } } else { if ( arr[17] <= -35.5 ) { return -1.04375; } else { return -0.169512; } } } else { return -0.961178; } } } else { return -0.233391; } } else { if ( arr[20] <= -23.5 ) { if ( arr[25] <= 30.5 ) { return -0.281633; } else { return 0.554298; } } else { if ( arr[0] <= 5.5 ) { if ( arr[3] <= 4.5 ) { if ( arr[26] <= 23.5 ) { return 0.575822; } else { return 1.45787; } } else { return 0.186044; } } else { if ( arr[23] <= -2.5 ) { if ( arr[11] <= 33.5 ) { if ( arr[25] <= 14.5 ) { return 0.597155; } else { return -0.420236; } } else { return 1.21965; } } else { return -0.242002; } } } } } else { if ( arr[5] <= 5.5 ) { return 0.128822; } else { return 1.07283; } } } else { if ( arr[25] <= 30.5 ) { if ( arr[19] <= -44.5 ) { return 0.238409; } else { if ( arr[4] <= 4.5 ) { if ( arr[24] <= -49.5 ) { return -0.958614; } else { if ( arr[24] <= -34.5 ) { return 0.289861; } else { return 0.0101188; } } } else { return -0.0913412; } } } else { return -0.122343; } } }
double predictTree2(double[] arr) { if ( arr[14] <= 48.5 ) { if ( arr[13] <= -30.5 ) { if ( arr[21] <= 45.5 ) { if ( arr[3] <= 46.5 ) { return 0.0707513; } else { return -0.566098; } } else { return 0.617633; } } else { if ( arr[16] <= -25.5 ) { if ( arr[21] <= 19.5 ) { if ( arr[9] <= 6.5 ) { if ( arr[12] <= -45.5 ) { return 0.943692; } else { if ( arr[4] <= -8.5 ) { return 0.290743; } else { if ( arr[19] <= -23.5 ) { if ( arr[26] <= -8.5 ) { return -1.06871; } else { if ( arr[0] <= -0.5 ) { return -0.759557; } else { return 0.314001; } } } else { if ( arr[7] <= -41.5 ) { return 1.07732; } else { if ( arr[18] <= 30.5 ) { return 0.110816; } else { return -0.502498; } } } } } } else { if ( arr[5] <= 3.5 ) { return 0.105686; } else { if ( arr[11] <= 33.5 ) { return -0.548442; } else { if ( arr[4] <= 0.5 ) { return 1.04067; } else { return -0.412117; } } } } } else { if ( arr[16] <= -39.5 ) { if ( arr[8] <= 14.5 ) { if ( arr[8] <= 5.5 ) { if ( arr[4] <= 27.5 ) { return -0.224228; } else { return 0.725575; } } else { return 1.11628; } } else { return -0.430015; } } else { if ( arr[7] <= 30.5 ) { if ( arr[18] <= -2.5 ) { if ( arr[12] <= 27.5 ) { return -0.243358; } else { return 0.635237; } } else { return 0.534673; } } else { return 0.944614; } } } } else { if ( arr[25] <= -30.5 ) { if ( arr[4] <= 26.5 ) { if ( arr[13] <= -24.5 ) { return 0.692067; } else { if ( arr[26] <= -25.5 ) { return -0.127491; } else { return 0.190598; } } } else { if ( arr[26] <= 28.5 ) { return -0.100314; } else { return -0.796962; } } } else { return -0.0923691; } } } } else { if ( arr[26] <= -7.5 ) { return 0.931673; } else { return -0.00291374; } } }
double predictTree3(double[] arr) { if ( arr[1] <= 48.5 ) { if ( arr[14] <= 48.5 ) { if ( arr[10] <= 29.5 ) { if ( arr[8] <= -46.5 ) { if ( arr[10] <= -24.5 ) { return 0.74522; } else { return 0.13412; } } else { if ( arr[4] <= -0.5 ) { if ( arr[5] <= -36.5 ) { if ( arr[23] <= -31.5 ) { return 0.787587; } else { if ( arr[8] <= -5.5 ) { if ( arr[5] <= -47.5 ) { return -0.226148; } else { if ( arr[15] <= -9.5 ) { return 0.148057; } else { return 1.12421; } } } else { if ( arr[2] <= -31.5 ) { return -0.656214; } else { if ( arr[0] <= -28.5 ) { return -0.520197; } else { return 0.383778; } } } } } else { if ( arr[23] <= -47.5 ) { return -0.449624; } else { if ( arr[16] <= 7.5 ) { return 0.106182; } else { return -0.0755566; } } } } else { if ( arr[14] <= 30.5 ) { return -0.102649; } else { if ( arr[5] <= -44.5 ) { return -0.522098; } else { return 0.171904; } } } } } else { if ( arr[2] <= -21.5 ) { if ( arr[20] <= 20.5 ) { return -0.419594; } else { if ( arr[21] <= -25.5 ) { return -0.734804; } else { if ( arr[8] <= -25.5 ) { return -0.503601; } else { return 0.584065; } } } } else { if ( arr[2] <= -3.5 ) { if ( arr[7] <= -38.5 ) { return -0.627834; } else { if ( arr[21] <= 5.5 ) { if ( arr[8] <= 1.5 ) { return -0.380093; } else { return 0.370773; } } else { if ( arr[26] <= -9.5 ) { return 1.07691; } else { return 0.278066; } } } } else { if ( arr[11] <= 27.5 ) { return -0.188583; } else { return 0.157865; } } } } } else { if ( arr[5] <= 8.5 ) { if ( arr[0] <= -10.5 ) { return 0.00207745; } else { return 1.29174; } } else { return -0.132605; } } } else { if ( arr[6] <= 18.5 ) { return 0.791693; } else { return -0.182431; } } }
double predict(double[] arr) { return predictTree0(arr) + predictTree1(arr) + predictTree2(arr) + predictTree3(arr);}
@guolinke I don't understand. Is there any decision_type other than no_greater?
I chose java mainly because it can be used in hadoop or spark.
@wxchan
For a categorical feature, it may be `if (x == 1)` (decision_type=1).
I would like to have both C++ and Java versions (there are only a few differences).
Some other suggestions:
@guolinke decision_type=0 means no_greater, decision_type=1 means is?
do you mean to write this parser in c++/java or generated codes in c++/java(like what I do now)?
@wxchan
yes.
Using Python to generate this code is fine.
I mean generate both cpp and java codes.
@guolinke what's your future plan on this feature, after #469 ?
@wxchan , I think it is good enough of #469 . Do you have any improvement ?
@guolinke Not really, yet. I just wonder how this feature works on bigger datasets like #396 and #435. Maybe add support for it in the python-package if it really helps.
@wxchan
You mean using Python to generate the code?
Actually, I think it will be better if we can auto replace the code of predict in gbdt.cpp.
I mean expose it to c_api.
Actually recompile with long hard-coded predict function takes a lot of time on my machine, I don't know if it happens on other machines.
Most helpful comment
I wrote some python codes like below, it converts dumped trees into c++/java code:
I made a simple test with java, the result code is as follow, the prediction result is same as inner predict function: