# basic simulator for our "PRIMA" (Primitive Maschine) 8-bit computer.
#
import datetime

class PRIMA:
    """
    basic simulation model of the PRIMA (primitive machine) von-Neumann
    computer used in our 'Rechnerstrukturen und Betriebssysteme' lecture.

    The machine has 256 memory cells and the registers PC (program counter),
    BR (holds the fetched opcode), AR (holds the fetched address byte), the
    accumulator (all masked to 8 bits by their setters) and a 1-bit switch
    register SW. Every instruction takes three clock cycles: fetch opcode
    into BR, fetch address byte into AR, execute.
    """

    # opcode -> (mnemonic, whether the mnemonic shows the address byte in AR)
    _MNEMONICS = {
        0x00: ( "nop",   False ),
        0x01: ( "clear", False ),
        0x02: ( "load",  True ),
        0x03: ( "store", True ),
        0x10: ( "incr",  False ),
        0x11: ( "decr",  False ),
        0x12: ( "add",   True ),
        0x13: ( "sub",   True ),
        0x20: ( "neg",   False ),
        0x21: ( "and",   True ),
        0x22: ( "or",    True ),
        0x23: ( "xor",   True ),
        0x40: ( "jump",  True ),
        0x41: ( "bze",   True ),
        0x42: ( "bgt",   True ),
        0x45: ( "bsw",   True ),
        0xff: ( "halt",  False ),
    }

    def __init__( self ):
        self.verbose = True    # print the registers after every clock cycle
        self.halt = False      # set by the 'halt' instruction (opcode 0xff)

        self.AR = 0            # address register (operand byte of the instruction)
        self.BR = 0            # register holding the fetched opcode
        self.PC = 0            # program counter
        self.SW = 0            # 1-bit switch register, tested by 'bsw' (0x45)
        self.accu = 0          # accumulator
        self.cycle = 0         # clock cycle counter, cycle % 3 is the phase
        # initial memory content: every cell holds its own address
        self.memory = list( range( 256 ))
        return

    def reset( self ):
        """
        restart execution at address 0: clear the cycle counter, the PC and
        the halt flag. Memory, accumulator and the other registers are kept.
        """
        self.cycle = 0
        self.PC = 0
        self.halt = False   # otherwise a halted machine could never run again
        return

    def is_halted( self ):
        """return True once a 'halt' instruction has been executed."""
        return self.halt

    def set_verbose( self, b = True ):
        """enable/disable printing the registers after every clock cycle."""
        self.verbose = b
        return self.verbose

    def set_PC( self, value ):
        self.PC = value & 0xff
        return

    def set_AR( self, value ):
        self.AR = value & 0xff
        return

    def set_BR( self, value ):
        self.BR = value & 0xff
        return

    def set_SW( self, value ):
        self.SW = value & 0x01
        return

    def set_accumulator( self, value ):
        self.accu = value & 0xff
        return

    def clk( self, print_registers=False ):
        """
        execute one clock cycle.

        phase 0 (cycle % 3 == 0): fetch the opcode at PC into BR, PC += 1
        phase 1:                  fetch the address byte at PC into AR, PC += 1
        phase 2:                  execute the instruction in BR
        The registers are printed afterwards if print_registers or verbose
        is set. A halted machine does nothing and keeps its cycle counter.
        """
        if self.halt:
            if print_registers: self.print_registers()
            return

        phase = self.cycle % 3
        if phase == 0:      # fetch instruction
            self.BR = self.memory[ self.PC ]
            self.PC = (self.PC + 1) % 256
        elif phase == 1:    # fetch address
            self.AR = self.memory[ self.PC ]
            self.PC = (self.PC + 1) % 256
        else:               # execute
            self._execute( self.BR )

        if print_registers or self.verbose: self.print_registers()
        self.cycle += 1
        return

    def _execute( self, opcode ):
        """execute one instruction; memory operands and branch targets come from AR."""
        # match+case only introduced in python 3.10, we stay with if/elif/else...
        if   opcode == 0x00: # nop, no operation
            pass
        elif opcode == 0x01: # clear, accumulator = 0
            self.accu = 0
        elif opcode == 0x02: # load, accu = MEM[ AR ]
            self.accu = self.memory[ self.AR ]
        elif opcode == 0x03: # store, MEM[ AR ] = accu
            self.memory[ self.AR ] = self.accu
        elif opcode == 0x10: # incr, accu = accu + 1 (mod 256)
            self.accu = (self.accu + 1) & 0xff
        elif opcode == 0x11: # decr, accu = accu - 1 (mod 256)
            self.accu = (self.accu - 1) & 0xff
        elif opcode == 0x12: # add, accu = accu + MEM[ AR ] (mod 256)
            self.accu = (self.accu + self.memory[ self.AR ]) & 0xff
        elif opcode == 0x13: # sub, accu = accu - MEM[ AR ] (mod 256)
            self.accu = (self.accu - self.memory[ self.AR ]) & 0xff
        elif opcode == 0x20: # neg, bitwise not, accu = ~ accu
            self.accu = (~self.accu) & 0xff
        elif opcode == 0x21: # bitwise and, accu = accu & MEM[ AR ]
            self.accu = self.accu & self.memory[ self.AR ]
        elif opcode == 0x22: # bitwise or,  accu = accu | MEM[ AR ]
            self.accu = self.accu | self.memory[ self.AR ]
        elif opcode == 0x23: # bitwise xor, accu = accu ^ MEM[ AR ]
            self.accu = self.accu ^ self.memory[ self.AR ]
        elif opcode == 0x40: # jump
            self.PC = self.AR
        elif opcode == 0x41: # branch if zero
            if self.accu == 0: self.PC = self.AR
        elif opcode == 0x42: # branch if negative (accu msb set)
            if (self.accu & 0x80) != 0: self.PC = self.AR
        elif opcode == 0x45: # branch if switch
            if (self.SW & 0x01) != 0: self.PC = self.AR
        elif opcode == 0xff: # halt
            self.halt = True
        else: # illegal instruction
            print( "Illegal instruction " + str(opcode) + " (0x" + self.hex2(opcode) + "), ignored." )
        return

    def instruction( self, print_registers=False ):
        """
        execute or finish one instruction: run the remaining clock cycles up
        to and including the next execute phase (1 to 3 cycles, depending on
        the current phase). In verbose mode the registers are printed after
        every cycle; print_registers additionally prints them after the final
        (execute) cycle, whatever phase the instruction was started in.
        """
        remaining = 3 - self.cycle % 3
        for _ in range( remaining - 1 ):
            self.clk()                 # clk() itself prints when verbose
        self.clk( print_registers )
        return

    def get_mnemonic( self ):
        """
        describe the current cycle: "fetch inst" / "fetch addr" in phases 0/1,
        otherwise the mnemonic of the opcode in BR (with the AR operand for
        instructions that use it), or "illegal opcode".
        """
        phase = self.cycle % 3
        if phase == 0: return "fetch inst"
        if phase == 1: return "fetch addr"

        entry = self._MNEMONICS.get( self.BR )
        if entry is None:
            return "illegal opcode"
        name, shows_operand = entry
        if shows_operand:
            return name + " " + self.hex2( self.AR )
        return name

    def print_registers( self ):
        """print cycle counter/phase, all registers (hex) and the current mnemonic."""
        s = ""
        s += "cycle: " + "{:8d}".format( self.cycle ) + "." + str(self.cycle%3) + "   "
        s += "registers(hex)  "
        s += "PC: " + self.hex2( self.PC ) + "  "
        s += "BR: " + self.hex2( self.BR ) + "  "
        s += "AR: " + self.hex2( self.AR ) + "  "
        s += "accu: " + self.hex2( self.accu ) + "  "
        s += "SW: " + self.hex2( self.SW ) + "  "
        s += self.get_mnemonic()
        print( s )
        return

    def hex2( self, value ): # format value as 2-digit hex, but without 0x prefix
        return '{:02X}'.format( value )

    def hex4( self, value ): # format value as 4-digit hex, but without 0x prefix
        return '{:04X}'.format( value )

    def int4( self, value ): # format value as decimal, right-aligned in 4 columns
        return '{:4d}'.format( value )

    def write_memory( self, address, value ):
        """store the low 8 bits of value at address (mod 256)."""
        self.memory[ address&0xff ] = (value & 0xff)
        return

    def read_memory( self, address ):
        """return the memory cell at address (mod 256)."""
        return self.memory[ address&0xff ]

    def _dump_lines( self, start, stop ):
        """
        yield memory dump lines for the addresses start .. stop-1, 16 cells per
        line, in the layout 'AAAA:  vv vv vv vv  vv vv vv vv   vv vv vv vv  vv vv vv vv'
        (the format that load_memory reads back).
        """
        for line_start in range( start, stop, 16 ):
            values = [ self.hex2( self.read_memory( a ))
                       for a in range( line_start, min( line_start + 16, stop )) ]
            groups = [ " ".join( values[k:k+4] ) for k in range( 0, len(values), 4 ) ]
            # two spaces between groups of 4, three spaces in the middle of the line
            body = "  ".join( groups[:2] )
            if len( groups ) > 2:
                body += "   " + "  ".join( groups[2:] )
            yield self.hex4( line_start ) + ":  " + body

    def print_memory( self, start_addr, end_addr, fmt=0 ):
        """
        print a user-readable memory dump to stdout.
        NOTE: despite its name, end_addr is the NUMBER of cells to print:
        the cells start_addr .. start_addr+end_addr-1 are shown.
        formats:
        0: list of decimal values (16 per line)
        1: list of hex values (16 per line, extra gap after 8 values)
        2: address plus hex values (16 per line, grouped by 4, aligned)
        9: disassembled view (not yet implemented)
        """
        stop = start_addr + end_addr
        if fmt == 0:
            values = [ self.int4( self.read_memory( a )) for a in range( start_addr, stop ) ]
            for i in range( 0, len(values), 16 ):
                print( " ".join( values[i:i+16] ))

        elif fmt == 1:
            values = [ self.hex2( self.read_memory( a )) for a in range( start_addr, stop ) ]
            for i in range( 0, len(values), 16 ):
                row = values[i:i+16]
                line = " ".join( row[:8] )
                if len( row ) > 8:
                    line += "  " + " ".join( row[8:] )   # extra space for easier reading
                print( line )

        elif fmt == 2:
            for line in self._dump_lines( start_addr, stop ):
                print( line )

        elif fmt == 9:
            print( "Disassembled memory dump: NOT YET IMPLEMENTED." )

        else:
            print( "Unknown memory format, ignored. Please use help() for details." )
        return

    def load_memory( self, filename ):
        """
        load memory contents from the given text file, assuming utf-8 encoding.
        Every line in the file should start with an hex address, colon, and a
        number of whitespace-separated hex-formatted byte values.
        Lines starting with '#' are interpreted as comments; they are echoed
        to stdout and otherwise ignored. Lines without exactly one ':' (this
        includes empty lines) are reported and ignored. An invalid hex address
        or value raises ValueError.
        Example:
        # PRIMA demo program
        0000:  40 30
        0010:  02 F8 03 FA  01 00 03 FB   02 FA 41 60  02 FB 12 F9
        00f6:  F6 F7
        """
        with open( filename, "r", encoding="utf-8" ) as f:
            for line in f:
                if line.startswith( "#" ):
                    print( line, end='' )
                    continue
                addrdata = line.split( ":" )
                if len( addrdata ) != 2:
                    print( "invalid/empty line ignored: " + line )
                    continue
                addr = int( addrdata[0], 16 )
                for token in addrdata[1].split():   # any whitespace (spaces, tabs, newline)
                    self.write_memory( addr, int( token, 16 ))
                    addr += 1
        return

    def save_memory( self, filename, start=0, end=256, comment="" ):
        """
        write the memory cells start .. end-1 (end is exclusive) to a utf-8
        text file in the format read by load_memory. The data is preceded by
        a '#' line with the current date/time and, if given, a '#' comment line.
        """
        print( "save_memory: writing data to '" + filename + "'" )
        with open( filename, "w", encoding="utf-8" ) as f:
            f.write( "# PRIMA simulator memory dump, created on " + str( datetime.datetime.now() ) + "\n" )
            if comment != "": f.write( "#" + comment + "\n" )
            for line in self._dump_lines( start, end ):
                f.write( line + "\n" )
        print( "save_memory ok." )
        return

    def memset( self, start_addr = 10, values = 0, size = 0):
        """
        starting at the given start address, fill size cells of memory with the given value(s).
        Use this method to clear or initialize an area of memory. Variants:
        memset( start_address, value, size ): fill constant value into start_address .. start_address+size-1
        memset( start_address, [array] ):     use values from list/tuple starting at start_address
        memset( start_address, {dict} ):      set addr:value key:data pairs (start_address is unused)
        Addresses and values are reduced to 8 bits by write_memory.
        """
        if isinstance( values, dict ):
            for addr, value in values.items():
                self.write_memory( addr, value )

        elif isinstance( values, (list, tuple) ):
            for offset, value in enumerate( values ):
                self.write_memory( start_addr + offset, value )

        else: # assume integer
            for addr in range( start_addr, start_addr + size ):
                self.write_memory( addr, values )
        return


def demo_1_endless_decrement_loop():
    """
    demo: endless loop that keeps decrementing memory cell 253 (0xFD).

    program (address: bytes  mnemonic):
      00: 01 00   clear
      02: 03 FD   store FD
      04: 11 00   decr
      06: 03 FD   store FD
      08: 40 04   jump 04
    Only 40 instructions are executed; after each instruction the memory
    area E0..FF is dumped.
    """
    prima = PRIMA()
    # clear memory, 'load' demo program to decrement mem[253]
    prima.memset( 0, values=0, size=256 )
    prima.memset( 0, values=[1,0, 3,253, 17,0, 3,253, 64,4, 0,0] )

    # the loop would repeat forever, but we only execute 40
    # instructions (3 clocks each, one loop iteration per instruction)
    prima.set_verbose( True )
    for _ in range( 40 ):
        prima.clk()
        prima.clk()
        prima.clk()
        prima.print_memory( 0xe0, 32, 2 )
    return


def demo_5_strlen( addr ):
    """
    placeholder, not implemented yet: count the number of chars in the
    string starting at addr, result in address 240. Currently a no-op.
    """
    return None

def demo_6_strcpy( addr1, addr2 ):
    """
    placeholder, not implemented yet: copy the string starting at addr2
    into memory starting at addr1. Currently a no-op.
    """
    return None


def test_load_file():
    """
    write a small memory-image text file (comment lines, short lines,
    irregular spacing, an empty line and a line without address) into a
    temporary directory, load it with PRIMA.load_memory and dump the
    whole memory. The temporary file is removed afterwards.
    """
    # local imports: only this test needs them
    import os
    import tempfile

    prima = PRIMA()

    # phase 0: clear memory
    prima.memset( 0, 0, size=256 )

    # a temporary directory instead of a hard-coded /tmp path keeps the test
    # portable (e.g. Windows) and cleans up the file afterwards
    with tempfile.TemporaryDirectory() as tmpdir:
        filename = os.path.join( tmpdir, "prima.txt" )

        # phase 1: create test file, write some data to it
        with open( filename, "w", encoding="utf-8" ) as f:
            f.write( "# comment line \n" )
            f.write( "## another comment \n" )
            f.write( "0000:  40 30 00 00  00 00 00 00   00 00 00 00  00 00 00 0F\n" )
            f.write( "0010:10\n" )
            f.write( "0018:18\n" )
            f.write( "001b: 1b  1c   1d     1e       1f        \n" )
            f.write( "0020:20 aa bb cc dd ee ff 00 01 02\n" )
            f.write( "# testing extra newline:\n" )
            f.write( "00f7:  F7 F8 F9 \n\n" )
            f.write( "# testing invalid line without address:\n" )
            f.write( "00fA:  FA\n FB\n" )
            f.write( "00FF:  42\n" )
            f.write( "# end of file\n" )

        # phase 2: test load_memory function
        prima.load_memory( filename )

    prima.print_memory( 0, 256, 2 )
    return


def selftest():
    """
    quick smoke test of the simulator: fill memory, dump it in all three
    implemented formats, print the registers, then run one clock cycle
    followed by two instructions.
    """
    print( "PRIMA simulator selftest..." )
    machine = PRIMA()

    machine.memset( 10, 42, 5 )                      # cells 10..14 := 42
    machine.print_memory( 0, 30, 0 )                 # decimal dump

    machine.memset( 0, [42] + list( range( 1, 10 )) )   # [42, 1, 2, ..., 9]
    machine.print_memory( 0, 127, 1 )                # plain hex dump
    machine.print_memory( 0, 256, 2 )                # hex dump with addresses

    machine.print_registers()

    machine.clk()
    for _ in range( 2 ):
        machine.instruction()
    return


if __name__ == "__main__":
    # other entry points, enable as needed:
    #   selftest()
    #   aufgabe_10_3_b.txt ...
    demo_1_endless_decrement_loop()


