diff --git a/src/imagination/pco/pco_map.py b/src/imagination/pco/pco_map.py index ff5b801541d..ce1d35313bc 100644 --- a/src/imagination/pco/pco_map.py +++ b/src/imagination/pco/pco_map.py @@ -265,15 +265,16 @@ enum_map(RM_ELEM.t, F_MASKW0, [ ], pass_zero=['e0', 'e1', 'e2', 'e3']) class OpRef(object): - def __init__(self, ref_type, index): + def __init__(self, ref_type, index, mods): self.type = ref_type self.index = index + self.mods = mods -def SRC(index): - return OpRef('src', index) +def SRC(index, mods=[]): + return OpRef('src', index, mods) -def DEST(index): - return OpRef('dest', index) +def DEST(index, mods=[]): + return OpRef('dest', index, mods) encode_maps = {} group_maps = {} @@ -331,6 +332,7 @@ def encode_map(op, encodings, op_ref_maps): elif isinstance(val_spec, tuple): mod, _origin = val_spec assert isinstance(_origin, OpRef) + assert not _origin.mods origin = f'{_origin.type}[{_origin.index}]' if isinstance(mod, RefMod): @@ -386,6 +388,7 @@ def encode_map(op, encodings, op_ref_maps): mod, _origin, cond = isa_op_cond assert isinstance(_origin, OpRef) + assert not _origin.mods origin = f'{_origin.type}[{_origin.index}]' assert isinstance(mod, RefMod) @@ -399,6 +402,7 @@ def encode_map(op, encodings, op_ref_maps): elif isinstance(isa_op_cond[0], str): mod, _origin = isa_op_cond assert isinstance(_origin, OpRef) + assert not _origin.mods origin = f'{_origin.type}[{_origin.index}]' conds_variant += f'{mod}({{1}}->{origin})' else: @@ -509,6 +513,7 @@ def group_map(op, hdr, enc_ops, srcs=[], iss=[], dests=[]): elif isinstance(val_spec, tuple): mod, _origin = val_spec assert isinstance(_origin, OpRef) + assert not _origin.mods origin = f'{_origin.type}[{_origin.index}]' assert isinstance(mod, str) and isinstance(origin, str) hdr_mappings.append(f'{{}}->hdr.{hdr_field} = {mod}({{}}->{origin});') @@ -565,8 +570,17 @@ def group_map(op, hdr, enc_ops, srcs=[], iss=[], dests=[]): io = IO.enum.elems[_ref] enc_mapping += f', pco_ref_io({io.cname})' elif isinstance(_ref, OpRef): - origin = f'{_ref.type}[{_ref.index}]' - enc_mapping += f', {{1}}->{origin}' + origin = f'{{1}}->{_ref.type}[{_ref.index}]' + + if _ref.type == 'src': + assert all(mod in enc_op.src_mods[_ref.index] for mod in _ref.mods) + else: + assert all(mod in enc_op.dest_mods[_ref.index] for mod in _ref.mods) + + for mod in _ref.mods: + origin = f'pco_ref_{mod.t.tname}({origin})' + + enc_mapping += f', {origin}' elif isinstance(_ref, str): if _ref == '_': enc_mapping += ', pco_ref_null()' @@ -621,6 +635,7 @@ def group_map(op, hdr, enc_ops, srcs=[], iss=[], dests=[]): phase = OP_PHASE.enum.elems[_phase].cname assert isinstance(val_spec, OpRef) + assert not val_spec.mods origin = f'{val_spec.type}[{val_spec.index}]' assert _io in IO.enum.elems.keys() @@ -660,6 +675,7 @@ def group_map(op, hdr, enc_ops, srcs=[], iss=[], dests=[]): phase = OP_PHASE.enum.elems[_phase].cname assert isinstance(val_spec, OpRef) + assert not val_spec.mods origin = f'{val_spec.type}[{val_spec.index}]' if _io in IO.enum.elems.keys():