Keras: [Request] model.summary() should return its output as a string

Created on 22 Mar 2017 · 10 Comments · Source: keras-team/keras

Have model.summary() return its output as a string. Using to_json() or to_yaml() is not user-friendly for quickly browsing a model; the output of .summary() is at the right level of detail. Having .summary() return a string would be useful for logging.

Please give this a thumbs up if you'd also like this feature.

All 10 comments

There was actually a related pull request in the past (#3479).

@andreh7 @fchollet can weigh in. Perhaps this makes more sense to have in Keras 2.0. I don't see why .summary() should print() its output instead of returning it; that makes it hard to log.

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

While discussing this with @vlimant yesterday, I remembered using the following workaround (not thread-safe) until we hear from the project maintainers:

import cStringIO as StringIO
import sys

# keep track of the original sys.stdout
origStdout = sys.stdout

# replace sys.stdout temporarily with our own buffer
outputBuf = StringIO.StringIO()
sys.stdout = outputBuf
try:
    # print the model summary into the buffer
    model.summary()
finally:
    # put back the original stdout, even if summary() raises
    sys.stdout = origStdout

# get the model summary as a string
modelDescription = outputBuf.getvalue()

Any update on this?

You can pass your own print_fn function to summary:

model.summary(print_fn=my_function)

print_fn is called once for every line of the summary, so it's easy to use a custom function to store the summary as a string.
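As a sketch of the logging use case from the original request, print_fn can forward each summary line to Python's standard logging module. The helper name log_model_summary and the logger name "model_summary" are hypothetical; the only Keras API assumed is model.summary(print_fn=...). The **kwargs parameter is defensive, on the assumption that some Keras versions may pass extra keyword arguments to print_fn.

```python
import logging


def log_model_summary(model, logger=None, level=logging.INFO):
    """Send each line of model.summary() to a logger (hypothetical helper)."""
    if logger is None:
        logger = logging.getLogger("model_summary")

    def _print_fn(line, **kwargs):
        # **kwargs is defensive: assumes some Keras versions may pass
        # extra keyword arguments to print_fn; they are ignored here.
        logger.log(level, line)

    model.summary(print_fn=_print_fn)
```

Usage: call logging.basicConfig(level=logging.INFO) once, then log_model_summary(model).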

For Python 3:

from io import StringIO
import sys

origStdout = sys.stdout
outputBuf = StringIO()
sys.stdout = outputBuf
try:
    model.summary()
finally:
    # always restore the real stdout, even if summary() raises
    sys.stdout = origStdout
model_description = outputBuf.getvalue()

print(model_description)
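On Python 3.4+, the standard library's contextlib.redirect_stdout performs the same swap and restores sys.stdout automatically. This sketch assumes summary() writes through print to sys.stdout, as the workaround above does. Like the manual version, it changes the process-wide sys.stdout, so it is still not thread-safe. The helper name summary_via_stdout is hypothetical.

```python
import contextlib
import io


def summary_via_stdout(model):
    """Capture whatever model.summary() prints to sys.stdout (hypothetical helper)."""
    buf = io.StringIO()
    # redirect_stdout swaps sys.stdout for buf and restores it on exit,
    # even if model.summary() raises.
    with contextlib.redirect_stdout(buf):
        model.summary()
    return buf.getvalue()
```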

A little less hacky:

from io import StringIO

tmp_smry = StringIO()
# each print_fn call passes one summary line; add the newline yourself
model.summary(print_fn=lambda x: tmp_smry.write(x + '\n'))
summary = tmp_smry.getvalue()

Same as above, without the StringIO import:

import tensorflow as tf

def summary(model: tf.keras.Model) -> str:
    lines = []
    model.summary(print_fn=lambda x: lines.append(x))
    return '\n'.join(lines)