Source code for asm_analyser.architectures.arm.branch_pred

'''Implements everything required (including different strategies) for
branch prediction simulation with ARM assembly.
'''
import re
from typing import List, Dict
from asm_analyser import branch_pred
from asm_analyser.blocks.code_block import CodeBlock


[docs]class ArmBranchPredictor(branch_pred.BranchPredictor): '''Implements the BranchPredictor class for ARM assembly. '''
[docs] def one_bit(self) -> str: '''Branch prediction using one bit (saturating counter). Returns ------- str C code containing all necessary elements for this branch predictor. ''' result = '' branch_count = 0 branch_index = 0 for line in self.c_code.splitlines(): if '//BRANCHTAKEN' in line: branch_count += 1 for line in self.c_code.splitlines(): if '//BPDEFS' in line: result += f'uint8_t branch_bits[{branch_count}];\n' result += f'int cond_branches[{branch_count}];\n' result += f'int mispredictions[{branch_count}];\n' elif '//BPINIT' in line: if branch_count > 0: result += '.branch_bits = {0}, .cond_branches = {0}, ' result += '.mispredictions = {0}\n' elif '//BRANCHTAKEN' in line: result += ( f'_asm_analysis_.cond_branches[{branch_index}]++;\n' f'if(_asm_analysis_.branch_bits[{branch_index}] == 0) {{\n' f'_asm_analysis_.mispredictions[{branch_index}]++;\n' f'_asm_analysis_.branch_bits[{branch_index}] = 1;\n' f'}}\n' ) elif '//BRANCHNOTTAKEN' in line: result += ( f'_asm_analysis_.cond_branches[{branch_index}]++;\n' f'if(_asm_analysis_.branch_bits[{branch_index}] == 1) {{\n' f'_asm_analysis_.mispredictions[{branch_index}]++;\n' f'_asm_analysis_.branch_bits[{branch_index}] = 0;\n' f'}}\n' ) branch_index += 1 elif 'BPSTART' in line: pass elif 'BPEND' in line: pass else: result += f'{line}\n' return result[:-2]
[docs] def two_bit1(self) -> str: '''Branch prediction using two bits (saturating counter). Returns ------- str C code containing all necessary elements for this branch predictor. ''' result = '' branch_count = 0 branch_index = 0 for line in self.c_code.splitlines(): if '//BRANCHTAKEN' in line: branch_count += 1 for line in self.c_code.splitlines(): if '//BPDEFS' in line: result += f'uint8_t branch_bits[{branch_count}];\n' result += f'int cond_branches[{branch_count}];\n' result += f'int mispredictions[{branch_count}];\n' elif '//BPINIT' in line: if branch_count > 0: result += '.branch_bits = {0}, .cond_branches = {0}, ' result += '.mispredictions = {0}\n' elif '//BRANCHTAKEN' in line: result += ( f'_asm_analysis_.cond_branches[{branch_index}]++;\n' f'if(_asm_analysis_.branch_bits[{branch_index}] == 0 || _asm_analysis_.branch_bits[{branch_index}] == 1) {{\n' f'_asm_analysis_.mispredictions[{branch_index}]++;\n' f'_asm_analysis_.branch_bits[{branch_index}]++;\n' f'}}\n' f'else if(_asm_analysis_.branch_bits[{branch_index}] == 2) {{\n' f'_asm_analysis_.branch_bits[{branch_index}]++;\n' f'}}\n') elif '//BRANCHNOTTAKEN' in line: result += ( f'_asm_analysis_.cond_branches[{branch_index}]++;\n' f'if(_asm_analysis_.branch_bits[{branch_index}] == 2 || _asm_analysis_.branch_bits[{branch_index}] == 3) {{\n' f'_asm_analysis_.mispredictions[{branch_index}]++;\n' f'_asm_analysis_.branch_bits[{branch_index}]--;\n' f'}}\n' f'else if(_asm_analysis_.branch_bits[{branch_index}] == 1) {{\n' f'_asm_analysis_.branch_bits[{branch_index}]--;\n' f'}}\n') branch_index += 1 elif 'BPSTART' in line: pass elif 'BPEND' in line: pass else: result += f'{line}\n' return result[:-2]
[docs] def two_bit2(self) -> str: '''Another Branch prediction using two bits (bimodal predictor). Returns ------- str C code containing all necessary elements for this branch predictor. ''' result = '' branch_count = 0 branch_index = 0 for line in self.c_code.splitlines(): if '//BRANCHTAKEN' in line: branch_count += 1 for line in self.c_code.splitlines(): if '//BPDEFS' in line: result += f'uint8_t branch_bits[{branch_count}];\n' result += f'int cond_branches[{branch_count}];\n' result += f'int mispredictions[{branch_count}];\n' elif '//BPINIT' in line: if branch_count > 0: result += '.branch_bits = {0}, .cond_branches = {0}, ' result += '.mispredictions = {0}\n' elif '//BRANCHTAKEN' in line: result += ( f'_asm_analysis_.cond_branches[{branch_index}]++;\n' f'if(_asm_analysis_.branch_bits[{branch_index}] == 0){{\n' f'_asm_analysis_.mispredictions[{branch_index}]++;\n' f'_asm_analysis_.branch_bits[{branch_index}]++;\n' f'}}\n' f'else if(_asm_analysis_.branch_bits[{branch_index}] == 1) {{\n' f'_asm_analysis_.mispredictions[{branch_index}]++;\n' f'_asm_analysis_.branch_bits[{branch_index}] += 2;\n' f'}}\n' f'else if(_asm_analysis_.branch_bits[{branch_index}] == 2) {{\n' f'_asm_analysis_.branch_bits[{branch_index}]++;\n' f'}}\n' ) elif '//BRANCHNOTTAKEN' in line: result += ( f'_asm_analysis_.cond_branches[{branch_index}]++;\n' f'if(_asm_analysis_.branch_bits[{branch_index}] == 3){{\n' f'_asm_analysis_.mispredictions[{branch_index}]++;\n' f'_asm_analysis_.branch_bits[{branch_index}]--;\n' f'}}\n' f'else if(_asm_analysis_.branch_bits[{branch_index}] == 2) {{\n' f'_asm_analysis_.mispredictions[{branch_index}]++;\n' f'_asm_analysis_.branch_bits[{branch_index}] = 0;\n' f'}}\n' f'else if(_asm_analysis_.branch_bits[{branch_index}] == 1) {{\n' f'_asm_analysis_.branch_bits[{branch_index}]--;\n' f'}}\n' ) branch_index += 1 elif 'BPSTART' in line: pass elif 'BPEND' in line: pass else: result += f'{line}\n' return result[:-2]
[docs] def insert_branch_pred(self, method_name: str) -> str: '''Calls the desired branch prediction method. Parameters ---------- method_name : str Name of the desired branch prediction method. Returns ------- str C-code containing the necessary instructions for the branch prediction simulation. ''' if method_name == 'one_bit': return self.one_bit() elif method_name == 'two_bit1': return self.two_bit1() elif method_name == 'two_bit2': return self.two_bit2() else: return self.c_code
[docs] @staticmethod def is_branch_instr(opcode: str, *args) -> bool: if (re.match('^b(?!ic$).*', opcode) or (re.match('(^ldr.*)|(^ldm.*)|(^pop.*)', opcode) and 'pc' in args)): cond = False if re.match('(^ldr.*)|(^ldm.*)', opcode): digit_idx = re.search('\d', opcode).start() if opcode[digit_idx - 2:digit_idx] in COND_CODES: cond = True elif opcode[-2:] in COND_CODES: cond = True return cond return False
[docs] @staticmethod def write_rates(file_path: str, blocks: List[CodeBlock], branch_rates: List[float], branch_map: Dict[int, int]) -> None: asm_lines = [] line_index = 0 with open(file_path, 'r') as f: asm_lines = f.readlines() # Write back every asm line but add the branch prediction rate. with open(file_path, 'w') as f: for block in blocks: for instr in block.instructions: if instr[0] != -1: while line_index < instr[0]: f.write(f' {asm_lines[line_index]}') line_index += 1 if instr[0] in branch_map: branch_rate = branch_rates[branch_map[instr[0]]] branch_str = '{:.2f}'.format(branch_rate) f.write(f'{branch_str} {asm_lines[line_index]}') else: f.write(f' {asm_lines[line_index]}') line_index += 1 while line_index < len(asm_lines): f.write(f' {asm_lines[line_index]}') line_index += 1
BP_METHODS = ['one_bit', 'two_bit1', 'two_bit2'] COND_CODES = ['eq', 'ne', 'ge', 'gt', 'le', 'lt', 'ls', 'cs', 'cc', 'hi', 'mi', 'pl', 'al', 'nv', 'vs', 'vc']