C# と .NET Framework で作る簡単プロキシサーバ

前回よりも安定して動作するようになりました。
このプログラムは .NET Framework で用意されている HttpListener クラスと HttpWebRequest を使ってプロキシサーバ (Proxy Server) を実現します。
System.Net.HttpListener の仕様上、プラットフォームに制限があります。

このクラスは、Windows XP SP2 または Windows Server 2003 のオペレーティング システムを実行しているコンピュータでしか使用できません。それ以前のオペレーティング システムを実行しているコンピュータで HttpListener オブジェクトを作成しようとすると、コンストラクタから PlatformNotSupportedException 例外がスローされます。

非常にテキトーな作りになっていますので、実用的に使うことはできませんが、ニコニコ動画YouTube のキャッシュサーバなどは実現可能かと思われます。
ヘッダフィルタや HTML の書き換えなども実現できるとおもいます。
以下ソースコード

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Net;
using System.Diagnostics;
using System.IO;

namespace Samples
{
    class HttpServer
    {
        HttpListener listener;

        public bool IsListening
        {
            get
            {
                return (listener != null) && listener.IsListening;
            }
        }

        public HttpServer()
        {
            listener = null;
        }

        public void Start()
        {
            if (listener != null)
                return;

            Debug.WriteLine("Enter HttpServer::Start");

            listener = new HttpListener();
            listener.Prefixes.Add(string.Format("http://{0}:{1}/", IPAddress.Loopback, 8081));
            listener.Start();

            listener.BeginGetContext(EndGetContext, listener);
            Debug.WriteLine("Exit HttpServer::Start");
        }

        public void Stop()
        {
            if (listener == null)
                return;

            Debug.WriteLine("Enter HttpServer::Stop");
            listener.Stop();
            listener.Close();
            listener = null;
            Debug.WriteLine("Exit HttpServer::Stop");
        }

        private void EndGetContext(IAsyncResult ar)
        {
            Debug.WriteLine("Enter HttpServer::EndGetContext");
            HttpListener listener = ar.AsyncState as HttpListener;
            if (!listener.IsListening) // 呼び出した listener が Stop されているなら何もしない
            {
                Debug.WriteLine("Exit HttpServer::EndGetContext listener stopped");
                return;
            }

            HttpListenerContext context = null;
            try
            {
                listener.BeginGetContext(EndGetContext, listener);
                context = listener.EndGetContext(ar);
                HandleRequest(context);
            }
            catch (Exception ex)
            {
                Debug.WriteLine("Exception HttpServer::EndGetContext " + ex.Message);
                if (context != null)
                    context.Response.Abort();
            }
            finally
            {
                if (context != null)
                    context.Response.Close();
                Debug.WriteLine("Exit HttpServer::EndGetContext");
            }
        }

        private HttpWebRequest CreateRequest(HttpListenerContext context, bool keepalive)
        {
            // HttpListenerRequest と同じ HttpWebRequest を作る。
            HttpListenerRequest req = context.Request;
            HttpWebRequest webRequest = WebRequest.Create(req.RawUrl) as HttpWebRequest;
            webRequest.Method = req.HttpMethod;
            webRequest.ProtocolVersion = HttpVersion.Version11;

            // 接続してきたクライアントが切断しても、ほかのクライアントでこの WebRequest の接続を再利用できるはずなので常に true でもいいか?
            webRequest.KeepAlive = keepalive;

            // HttpWebRequest の制限がきついのでヘッダごとに対応
            for (int i = 0; i < req.Headers.Count; i++)
            {
                string name = req.Headers.GetKey(i).ToLower();
                string value = req.Headers.Get(i).ToLower();

                switch (name)
                {
                    case "host":
                        break; // WebRequest.Create で適切に設定されているはず あとで確認
                    case "connection":
                    case "proxy-connection":
                        webRequest.KeepAlive = keepalive; // TODO: keepalive の取得はここで行う。
                        break;
                    case "referer":
                        webRequest.Referer = value;
                        break;
                    case "user-agent":
                        webRequest.UserAgent = value;
                        break;
                    case "accept":
                        webRequest.Accept = value;
                        break;
                    case "content-length":
                        webRequest.ContentLength = req.ContentLength64;
                        break;
                    case "content-type":
                        webRequest.ContentType = value;
                        break;
                    case "if-modified-since":
                        webRequest.IfModifiedSince = DateTime.Parse(value);
                        break;
                    default:
                        try
                        {
                            // その他。上以外にも個別に対応しなければならないものがあるが面倒なのでパス
                            webRequest.Headers.Set(name, value);
                        }
                        catch
                        {
                            Debug.WriteLine("Exception HttpServer::CreateRequest header=" + name);
                        }
                        break;
                }
            }

            return webRequest;
        }

        private void HandleRequest(HttpListenerContext context)
        {
            // どういう処理をさせるのかを決める。
            HttpListenerRequest req = context.Request;
            HttpListenerResponse res = context.Response;

            // どこから接続されたかと、加工されていないアドレス
            Debug.WriteLine("Info HttpServer::HandleRequest " + string.Format("UserHost={0}: Request={1}", req.UserHostAddress, req.RawUrl));

            bool keepalive = req.KeepAlive; // 常に false らしい。バグらしい。
            if (!string.IsNullOrEmpty(req.Headers["Connection"]) && req.Headers["Connection"].IndexOf("keep-alive", StringComparison.InvariantCultureIgnoreCase) >= 0 ||
                !string.IsNullOrEmpty(req.Headers["Proxy-Connection"]) && req.Headers["Proxy-Connection"].IndexOf("keep-alive", StringComparison.InvariantCultureIgnoreCase) >= 0)
                keepalive = true; // バグ対策?

            if (req.RawUrl.StartsWith("/") || req.RawUrl.StartsWith("http://local.ptron/"))
                ProcessLocalRequest(context, keepalive);
            else // プロクシサーバとしての振る舞い
                ProcessProxyRequest(context, keepalive);
        }

        private void ProcessProxyRequest(HttpListenerContext context, bool keepalive)
        {
            // プロキシサーバとしての動作。
            HttpListenerRequest req = context.Request;
            HttpListenerResponse res = context.Response;

            HttpWebRequest webRequest = CreateRequest(context, keepalive);
            // ボディあったら送受信
            if (req.HasEntityBody) // リクエストのボディがある (POST とか)
                Relay(req.InputStream, webRequest.GetRequestStream());

            // レスポンス取得
            HttpWebResponse webResponse = null;
            try
            {
                webResponse = webRequest.GetResponse() as HttpWebResponse;
            }
            catch (WebException e)
            {
                webResponse = e.Response as HttpWebResponse; // レスポンスがあればとる。304 とかの場合。なければ null になる。
                Debug.WriteLine("Exception HttpServer::ProcessProxyRequest " + e.Message);
            }

            // だめだった時の処理。てきとう
            if (webResponse == null)
            {
                SendResponse(context, 502, "Bad Gateway", keepalive, null);
                return;
            }

            // ブラウザへ返すレスポンスの設定。あるていど。
            res.ProtocolVersion = HttpVersion.Version11; // 常に HTTP/1.1 としておく
            res.StatusCode = (int)webResponse.StatusCode;
            res.StatusDescription = webResponse.StatusDescription;

            res.KeepAlive = keepalive;

            for (int i = 0; i < webResponse.Headers.Count; i++)
            {
                string name = webResponse.Headers.GetKey(i).ToLower();
                string value = webResponse.Headers.Get(i).ToLower();

                switch (name)
                {
                    case "content-length":
                        res.ContentLength64 = webResponse.ContentLength;
                        break;
                    case "keep-alive": // どうやって設定しようか...
                        Debug.WriteLine("Info HttpServer::ProcessProxyRequest keep-alive: " + value);
                        break;
                    case "transfer-encoding":
                        res.SendChunked = value.IndexOf("chunked") >= 0 ? true : false;
                        break;
                    default:
                        try
                        {
                            res.Headers.Set(name, value);
                        }
                        catch
                        {
                            Debug.WriteLine("Exception HttpServer::ProcessProxyRequest header=" + name);
                        }
                        break;
                }
            }

            Relay(webResponse.GetResponseStream(), res.OutputStream);

            webResponse.Close();
        }

        private void ProcessLocalRequest(HttpListenerContext context, bool keepalive)
        {
            HttpListenerRequest req = context.Request;

            // 通常の HTTP サーバとしての振る舞い。または local.ptron へのアクセス。
            if (req.RawUrl.Equals("/") || req.RawUrl.Equals("http://local.ptron/"))
                SendResponse(context, 200, "OK", keepalive, Encoding.Default.GetBytes("Hello World"));
            else // favicon とか取りに来るので。
                SendResponse(context, 404, "Not Found", keepalive, Encoding.Default.GetBytes("404 not found"));
        }

        private void Relay(Stream input, Stream output)
        {
            // Stream から読めなくなるまで送受信。
            byte[] buffer = new byte[4096];
            while (true)
            {
                int bytesRead = input.Read(buffer, 0, buffer.Length);
                if (bytesRead == 0)
                    break;
                output.Write(buffer, 0, bytesRead);
            }

            input.Close();
            output.Close();
        }

        private void SendResponse(HttpListenerContext context, int code, string description, bool keepalive, byte[] body)
        {
            context.Response.StatusCode = code;
            context.Response.StatusDescription = description;
            context.Response.ProtocolVersion = HttpVersion.Version11;
            context.Response.KeepAlive = keepalive;
            if (body != null)
            {
                context.Response.ContentType = "text/plain";
                context.Response.ContentLength64 = body.Length;

                context.Response.OutputStream.Write(body, 0, body.Length);
                context.Response.OutputStream.Close();
            }
            else
            {
                context.Response.ContentLength64 = 0;
            }
        }

        private void SendFile(HttpListenerContext context, int code, string description, bool keepalive, byte[] body, string contentType)
        {
            context.Response.StatusCode = code;
            context.Response.StatusDescription = description;
            context.Response.ProtocolVersion = HttpVersion.Version11;
            context.Response.KeepAlive = keepalive;
            if (body != null)
            {
                context.Response.ContentType = contentType;
                context.Response.ContentLength64 = body.Length;

                context.Response.OutputStream.Write(body, 0, body.Length);
                context.Response.OutputStream.Close();
            }
            else
            {
                context.Response.ContentLength64 = 0;
            }
        }
    }
}