Sunday, December 13, 2009

Using heap as stack and finding stack usage

Recently I learned how memory allocated from the heap can be used as a stack, and how to find the stack usage of a particular function invocation. The following sample code runs a function call on a heap-allocated stack and measures how much of that stack the call used.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

static int depth = 0;
static void *prevESP;

void func()
{
    if (depth++ < 10)  // change the 10 to increase/decrease stack usage
        func();
}

int main(void)
{
    const int STACK_SIZE = 1024 * 1024 * 4; // we allocate 4MB for our heap stack
    unsigned char *heapStack = (unsigned char*)malloc(STACK_SIZE);
    unsigned char *heapStackBottom;

    if (heapStack == NULL)
        return 1;
    heapStackBottom = heapStack + STACK_SIZE; // move to the bottom of the heap stack (its highest address)

    // fill the stack with 0xaa byte so that we can detect the stack
    // usage by scanning non 0xaa byte
    memset(heapStack, 0xaa, STACK_SIZE);

    __asm
    {
        mov prevESP, ESP          // take backup of current stack pointer
        mov ESP, heapStackBottom  // store our heap stack bottom as current stack pointer  
    }

    func();

    __asm
    {
        mov ESP, prevESP  // restore the original stack pointer
    }

    // scan upwards from the lowest address for the first byte that is not
    // 0xaa; it marks the deepest point the heap stack reached
    heapStackBottom = heapStack;
    while (heapStackBottom < heapStack + STACK_SIZE && *heapStackBottom == 0xaa)
        heapStackBottom++;

    // the bytes below that point were never used; the rest is the stack usage
    printf("Heap stack usage %d\n", (int)(STACK_SIZE - (heapStackBottom - heapStack)));

    free(heapStack);

    return 0;
}
So in the above code we allocate 4MB of heap space to be used as a stack (let's call it the "heap stack"). Since the x86 stack grows downwards (e.g. from address 2000 towards 1000), we store the bottom of the heap stack, i.e. the address just past its highest byte, in the ESP register. Before modifying ESP we save its current value so that we can restore it after the function call. Both steps are done in inline assembly.

To find the stack usage of the call, we first fill the allocated heap area with a known pattern (here 0xaa, i.e. 0b10101010) and then call the function. When the function returns, we scan upwards from the lowest address of the block for the first byte that is not 0xaa; this is the deepest point the stack reached. The offset of that byte from the start of the block is the number of bytes that were never used, and subtracting it from the total heap stack size gives the actual stack usage. The result is a close estimate rather than an exact figure, since a used byte that happens to contain 0xaa at that boundary is counted as unused. Though the above code is written for the VC++ compiler (whose __asm blocks are only supported for 32-bit x86 targets), it can be ported to Linux by replacing the inline assembly with its GCC equivalent.
I thought of this technique while reading about the POSIX setcontext/getcontext/swapcontext functions and how they could be implemented. According to the C99 standard, longjmp is not guaranteed to work once the function that called setjmp has completed its execution (i.e. returned); the behavior is undefined, because the stack frame that setjmp recorded may already have been reused. POSIX solves this with makecontext, which runs a function on user-allocated memory supplied as its stack. Wondering how that could be implemented, I came up with the heap stack code above.

Tuesday, December 08, 2009

A simple TCP redirector in Python

A very simple TCP redirector, written in Python as an experiment. It forwards any data sent to a local port to a target host and port, and relays the replies back, acting as a bridge between the local port and the target host/port pair. You run it by specifying the local host, the local port, the target host and the target host's port. For example, the following redirects all HTTP requests sent to local port 8080 to Google's web server:
$ ./SimpleTCPRedirector localhost 8080 www.google.com 80
#!/usr/bin/env python

import socket
import threading
import select
import sys

terminateAll = False

class ClientThread(threading.Thread):
 def __init__(self, clientSocket, targetHost, targetPort):
  threading.Thread.__init__(self)
  self.__clientSocket = clientSocket
  self.__targetHost = targetHost
  self.__targetPort = targetPort
  
 def run(self):
  print("Client thread started")
  
  self.__clientSocket.setblocking(0)
  
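  targetHostSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  try:
   targetHostSocket.connect((self.__targetHost, self.__targetPort))
  except Exception as e:
   # could not reach the target: drop the client instead of leaking sockets
   print(e)
   targetHostSocket.close()
   self.__clientSocket.close()
   return
  targetHostSocket.setblocking(0)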
  
  clientData = b""      # data waiting to be written to the client
  targetHostData = b""  # data waiting to be written to the target host
  terminate = False
  while not terminate and not terminateAll:
   inputs = [self.__clientSocket, targetHostSocket]
   outputs = []
   
   if len(clientData) > 0:
    outputs.append(self.__clientSocket)
    
   if len(targetHostData) > 0:
    outputs.append(targetHostSocket)
   
   try:
    inputsReady, outputsReady, errorsReady = select.select(inputs, outputs, [], 1.0)
   except Exception as e:
    print(e)
    break
    
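   for inp in inputsReady:
    if inp == self.__clientSocket:
     try:
      data = self.__clientSocket.recv(4096)
     except Exception as e:
      print(e)
      data = b""  # treat a receive error like a closed connection

     if len(data) > 0:
      targetHostData += data  # queue client data for the target host
     else:
      terminate = True  # client closed the connection (or failed)
    elif inp == targetHostSocket:
     try:
      data = targetHostSocket.recv(4096)
     except Exception as e:
      print(e)
      data = b""  # treat a receive error like a closed connection

     if len(data) > 0:
      clientData += data  # queue target host data for the client
     else:
      terminate = True  # target host closed the connection (or failed)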
      
   for out in outputsReady:
    if out == self.__clientSocket and len(clientData) > 0:
     bytesWritten = self.__clientSocket.send(clientData)
     if bytesWritten > 0:
      clientData = clientData[bytesWritten:]
    elif out == targetHostSocket and len(targetHostData) > 0:
     bytesWritten = targetHostSocket.send(targetHostData)
     if bytesWritten > 0:
      targetHostData = targetHostData[bytesWritten:]
  
  self.__clientSocket.close()
  targetHostSocket.close()
  print("Client thread terminating")

if __name__ == '__main__':
 if len(sys.argv) != 5:
  print('Usage:\n\tpython SimpleTCPRedirector <localHost> <localPort> <targetHost> <targetPort>')
  print('Example:\n\tpython SimpleTCPRedirector localhost 8080 www.google.com 80')
  sys.exit(1)
 
 localHost = sys.argv[1]
 localPort = int(sys.argv[2])
 targetHost = sys.argv[3]
 targetPort = int(sys.argv[4])
  
 serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 serverSocket.bind((localHost, localPort))
 serverSocket.listen(5)
 print("Waiting for client...")
 while True:
  try:
   clientSocket, address = serverSocket.accept()
  except KeyboardInterrupt:
   print("\nTerminating...")
   terminateAll = True
   break
  ClientThread(clientSocket, targetHost, targetPort).start()
  
 serverSocket.close()