Wednesday, March 09, 2011

Mapping Sockets to a Process In .NET Code

One feature added to Fiddler a few years ago is the ability to map a given HTTP request back to the local process that initiated it. It turns out that this requires a bit of interesting code, because the .NET Framework itself doesn’t expose any built-in access to the relevant IPHelper APIs that provide this information.
I found a number of samples on the web, but performance is a critical consideration for Fiddler, because it needs to determine the originating process for every new connection. Hence, I’ve written the following code, which maximizes performance by minimizing how much data is copied from the unmanaged buffer returned by Windows into managed memory.

// This sample is provided "AS IS" and confers no warranties.
// You are granted a non-exclusive, worldwide, royalty-free license to reproduce this code,
// prepare derivative works, and distribute it or any derivative works that you create.
//
// This class invokes the Windows IPHelper APIs that allow us to map sockets to processes.
//
// A cache of recent lookups could improve performance further, but performance is already pretty good, and
// designing a reasonable cache-expiration policy could prove tricky. Client connection reuse already provides a significant
// optimization: the process is looked up only once per new connection, so a reused connection behaves much like a cache hit.
//

using System;
using System.Collections.Generic;
using System.Text;
using System.Runtime.InteropServices;
using System.Net.NetworkInformation;
using System.Net;
using System.Diagnostics;
using System.Collections;
namespace Fiddler
{
    internal class Winsock
    {
        #region IPHelper_PInvokes
        private const int AF_INET = 2;              // IPv4
        private const int AF_INET6 = 23;            // IPv6
        private const int ERROR_INSUFFICIENT_BUFFER = 0x7a;
        private const int NO_ERROR = 0x0;
        // Note: On Windows, C/C++'s unsigned long (ULONG) is ALWAYS 32 bits, even in 64-bit processes, unlike C#'s ulong, which is 64 bits. See http://medo64.blogspot.com/2009/05/why-ulong-is-32-bit-even-on-64-bit.html
        [DllImport("iphlpapi.dll", ExactSpelling = true, SetLastError = true)]
        private static extern uint GetExtendedTcpTable(IntPtr pTcpTable, ref UInt32 dwTcpTableLength, [MarshalAs(UnmanagedType.Bool)] bool sort, UInt32 ipVersion, TcpTableType tcpTableType, UInt32 reserved);
        /// <summary>
        /// Enumeration of possible queries that can be issued using GetExtendedTcpTable
        /// </summary>
        private enum TcpTableType
        {
            BasicListener,
            BasicConnections,
            BasicAll,
            OwnerPidListener,
            OwnerPidConnections,
            OwnerPidAll,
            OwnerModuleListener,
            OwnerModuleConnections,
            OwnerModuleAll
        }

/* This code is obsolete: I now use pointer arithmetic to access the table rows directly instead of mapping structs on top of the 
 * returned block of data. I'm keeping it here for now for debugging purposes.
        [StructLayout(LayoutKind.Sequential)]
        private struct TcpRow
        {
            [MarshalAs(UnmanagedType.U4)]
            internal TcpState state;
            [MarshalAs(UnmanagedType.U4)]
            internal UInt32 localAddr;
            [MarshalAs(UnmanagedType.U4)]
            internal UInt32 localPortInNetworkOrder;
            [MarshalAs(UnmanagedType.U4)]
            internal UInt32 remoteAddr;
            [MarshalAs(UnmanagedType.U4)]
            internal UInt32 remotePortInNetworkOrder;
            [MarshalAs(UnmanagedType.U4)]
            internal Int32 owningPid;
        }

        private static string TcpRowToString(TcpRow rowInput)
        {
            return String.Format(">{0}:{1} to {2}:{3} is {4} by 0x{5:x}",
                (rowInput.localAddr & 0xFF) + "." + ((rowInput.localAddr & 0xFF00) >> 8) + "." + ((rowInput.localAddr & 0xFF0000) >> 16) + "." + ((rowInput.localAddr & 0xFF000000) >> 24),
                ((rowInput.localPortInNetworkOrder & 0xFF00) >> 8) + ((rowInput.localPortInNetworkOrder & 0xFF) << 8),
                (rowInput.remoteAddr & 0xFF) + "." + ((rowInput.remoteAddr & 0xFF00) >> 8) + "." + ((rowInput.remoteAddr & 0xFF0000) >> 16) + "." + ((rowInput.remoteAddr & 0xFF000000) >> 24),
                ((rowInput.remotePortInNetworkOrder & 0xFF00) >> 8) + ((rowInput.remotePortInNetworkOrder & 0xFF) << 8),
                rowInput.state,
                rowInput.owningPid);
        }
 */
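
The commented-out TcpRowToString decodes each field by hand. The IPv4 address is stored as four octets in network order, so on a little-endian machine the first octet lands in the low byte of the UInt32; only the low 16 bits of each port field are meaningful, and they too are in network byte order. As a sketch (the class and method names below are hypothetical, not Fiddler's), the same decoding can be expressed with System.Net.IPAddress, assuming a little-endian host such as x86 or x64 Windows:

```csharp
using System;
using System.Net;

// Hypothetical helpers: decode raw DWORD fields from a TCP table row
// using System.Net.IPAddress instead of manual shifting and masking.
// Assumes a little-endian host.
internal static class TcpRowDecoding
{
    // The four octets are in network order in memory; read as a little-endian
    // UInt32, the first octet sits in the low byte, which is the layout the
    // IPAddress(long) constructor expects.
    internal static IPAddress DecodeAddress(uint rawAddr)
    {
        return new IPAddress(rawAddr);
    }

    // Only the low 16 bits are meaningful and they are big-endian, so truncate
    // to 16 bits and swap the two bytes into host order.
    internal static int DecodePort(uint rawPort)
    {
        return unchecked((ushort)IPAddress.NetworkToHostOrder((short)rawPort));
    }
}
```

For example, a raw address of 0x0100007F decodes to 127.0.0.1, and a raw port value of 0xB822 decodes to 8888, matching what the manual byte swap in TcpRowToString produces.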

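The declarations in the listing (GetExtendedTcpTable, ERROR_INSUFFICIENT_BUFFER, NO_ERROR, AF_INET and the OwnerPid* table types) fit the usual two-call pattern: call once with a null buffer to learn the required size, allocate, then call again. The following is a minimal sketch, not Fiddler's actual implementation, that combines that pattern with the pointer-arithmetic approach, reading only the local-port and PID fields of each row in place rather than marshaling every row into a struct. It assumes the IPv4 MIB_TCPTABLE_OWNER_PID layout (a DWORD row count followed by 24-byte rows of six DWORDs) and a little-endian host; the names OwnerPidLookup, FindOwningPid and FindPidInTable are hypothetical, and matching on the local port alone is a simplification.

```csharp
using System;
using System.Runtime.InteropServices;

// Sketch only: OwnerPidLookup, FindOwningPid and FindPidInTable are hypothetical names.
internal static class OwnerPidLookup
{
    private const uint AF_INET = 2;                    // IPv4
    private const uint NO_ERROR = 0;
    private const uint ERROR_INSUFFICIENT_BUFFER = 0x7a;
    private const int OwnerPidAll = 5;                 // TCP_TABLE_OWNER_PID_ALL

    // Assumed IPv4 layout: DWORD dwNumEntries, then MIB_TCPROW_OWNER_PID rows of six DWORDs.
    private const int RowCountSize = 4;
    private const int RowSize = 24;
    private const int LocalPortOffset = 8;             // dwLocalPort within a row
    private const int OwningPidOffset = 20;            // dwOwningPid within a row

    [DllImport("iphlpapi.dll", ExactSpelling = true, SetLastError = true)]
    private static extern uint GetExtendedTcpTable(IntPtr pTcpTable, ref uint dwTcpTableLength,
        [MarshalAs(UnmanagedType.Bool)] bool sort, uint ipVersion, int tableClass, uint reserved);

    // Windows only: returns the PID owning the IPv4 socket bound to localPort, or -1.
    internal static int FindOwningPid(int localPort)
    {
        uint size = 0;
        IntPtr buffer = IntPtr.Zero;
        try
        {
            // The first call (null buffer) reports the required size; loop in case
            // the table grows between calls.
            uint result = GetExtendedTcpTable(IntPtr.Zero, ref size, false, AF_INET, OwnerPidAll, 0);
            while (result == ERROR_INSUFFICIENT_BUFFER)
            {
                if (buffer != IntPtr.Zero)
                {
                    Marshal.FreeHGlobal(buffer);
                    buffer = IntPtr.Zero;
                }
                buffer = Marshal.AllocHGlobal((int)size);
                result = GetExtendedTcpTable(buffer, ref size, false, AF_INET, OwnerPidAll, 0);
            }
            return (result == NO_ERROR && buffer != IntPtr.Zero) ? FindPidInTable(buffer, localPort) : -1;
        }
        finally
        {
            if (buffer != IntPtr.Zero) Marshal.FreeHGlobal(buffer);
        }
    }

    // Scans the table in place, reading only the two fields needed from each row.
    internal static int FindPidInTable(IntPtr table, int localPort)
    {
        int rowCount = Marshal.ReadInt32(table);
        for (int i = 0; i < rowCount; i++)
        {
            int rowOffset = RowCountSize + i * RowSize;
            int rawPort = Marshal.ReadInt32(table, rowOffset + LocalPortOffset);
            // Low 16 bits hold the port in network byte order; swap the two bytes.
            int port = ((rawPort & 0xFF) << 8) | ((rawPort & 0xFF00) >> 8);
            if (port == localPort)
            {
                return Marshal.ReadInt32(table, rowOffset + OwningPidOffset);
            }
        }
        return -1;
    }
}
```

On Windows, FindOwningPid(clientPort) would return the PID of the process whose IPv4 socket is bound to that local port, or -1 if no row matches; the PID could then be passed to Process.GetProcessById to obtain a process name. IPv6 connections (AF_INET6) use a different row layout (MIB_TCP6ROW_OWNER_PID) and would need their own offsets.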