I'm running shape inference on a DLRM model, but the Gemm op can't infer its output shape. The reason is that Gemm's input 1 (input B) is in the graph's initializer list but not in the graph's input list. This model was converted from PyTorch. Why does the input TypeProto list have no information about the initializers, and is that reasonable?
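Here is why the missing shape of B blocks inference. Gemm computes Y = alpha * A' * B' + beta * C, where A' is (M, K) and B' is (K, N), so N can only come from B's shape. The following is a simplified, hypothetical sketch of that rule, not ONNX's actual implementation. The name `gemm_output_shape` is illustrative.

```python
from typing import Optional, Sequence, Tuple


def gemm_output_shape(a_shape: Optional[Sequence[int]],
                      b_shape: Optional[Sequence[int]],
                      trans_a: int = 0,
                      trans_b: int = 0) -> Optional[Tuple[int, int]]:
    """Return (M, N) for Y = alpha * A' * B' + beta * C, or None if unknown.

    Hypothetical helper: it only models the 2-D shape rule of Gemm.
    """
    if a_shape is None or b_shape is None:
        # Without the shapes of both A and B, (M, N) cannot be derived.
        return None
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise ValueError("Gemm expects 2-D A and B")
    # A' = transpose(A) if transA else A, with A' of shape (M, K)
    m, k_a = (a_shape[1], a_shape[0]) if trans_a else (a_shape[0], a_shape[1])
    # B' = transpose(B) if transB else B, with B' of shape (K, N)
    k_b, n = (b_shape[1], b_shape[0]) if trans_b else (b_shape[0], b_shape[1])
    if k_a != k_b:
        raise ValueError(f"inner dimensions differ: {k_a} vs {k_b}")
    return (m, n)
```

For example, PyTorch typically exports `nn.Linear` as Gemm with `transB=1`. In that case a (512, 13) weight initializer applied to a (32, 13) batch gives `gemm_output_shape([32, 13], [512, 13], trans_b=1)`, which is `(32, 512)`. If B's shape is unknown (`None`), the result is `None`.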
Initializers do not have to be included in the graph inputs (since IR version 4).
The real point is that the shape inference implementation should be able to use the shapes of initializers.
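In practice, this means a lookup order. If a name is a graph input or has a value_info, use its declared type. Otherwise, fall back to the dims of a matching initializer. The sketch below models this with plain Python stand-ins. `Graph`, `Initializer` and `lookup_shape` are hypothetical and are not part of the onnx package.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Initializer:
    """Hypothetical stand-in for onnx.TensorProto (only name and dims)."""
    name: str
    dims: List[int]


@dataclass
class Graph:
    """Hypothetical stand-in for onnx.GraphProto, with shapes as plain lists."""
    input_shapes: Dict[str, List[int]] = field(default_factory=dict)
    value_info_shapes: Dict[str, List[int]] = field(default_factory=dict)
    initializers: List[Initializer] = field(default_factory=list)


def lookup_shape(graph: Graph, name: str) -> Optional[List[int]]:
    """Find a tensor's shape, falling back to initializer dims."""
    if name in graph.input_shapes:
        # Declared graph inputs win: an initializer that is also listed as an
        # input is only a default value that callers may override.
        return graph.input_shapes[name]
    if name in graph.value_info_shapes:
        return graph.value_info_shapes[name]
    for init in graph.initializers:
        if init.name == name:
            # A constant initializer not in the inputs: its dims are its shape.
            return list(init.dims)
    return None
```

With this order, a Gemm weight stored only as an initializer still yields a usable shape. For a graph with input `A` of shape `[32, 13]` and initializer `B` with dims `[512, 13]`, `lookup_shape(g, "B")` returns `[512, 13]`.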
This appears to be the same issue as #2655.
@TMVector @PenghuiCheng, would either of you like to offer a fix for this gap?
I haven't had time to work on ONNX lately (sorry for the late reply), but I have been using this method as a stopgap at work.
```python
import onnx


def add_value_info_for_constants(model: onnx.ModelProto):
    """
    Currently onnx.shape_inference doesn't use the shape of initializers, so add
    that info explicitly as ValueInfoProtos.
    Mutates the model.
    Args:
        model: The ModelProto to update.
    """
    # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
    if model.ir_version < 4:
        return

    def add_const_value_infos_to_graph(graph: onnx.GraphProto):
        inputs = {i.name for i in graph.input}
        existing_info = {vi.name: vi for vi in graph.value_info}
        for init in graph.initializer:
            # Check it really is a constant, not an input
            if init.name in inputs:
                continue

            # The details we want to add
            elem_type = init.data_type
            shape = init.dims

            # Get existing or create new value info for this constant
            vi = existing_info.get(init.name)
            if vi is None:
                vi = graph.value_info.add()
                vi.name = init.name

            # Even though it would be weird, we will not overwrite info even if it doesn't match
            tt = vi.type.tensor_type
            if tt.elem_type == onnx.TensorProto.UNDEFINED:
                tt.elem_type = elem_type
            if not tt.HasField("shape"):
                # Mark the shape field as set even if the const is scalar (zero dims)
                tt.shape.SetInParent()
                for dim in shape:
                    tt.shape.dim.add().dim_value = dim

        # Handle subgraphs
        for node in graph.node:
            for attr in node.attribute:
                # Ref attrs refer to other attrs, so we don't need to do anything
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    add_const_value_infos_to_graph(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for g in attr.graphs:
                        add_const_value_infos_to_graph(g)

    return add_const_value_infos_to_graph(model.graph)
```
This should be resolved by https://github.com/onnx/onnx/pull/2901. Please reopen the issue if you still encounter it. Thanks.