C# で作る簡単プロクシサーバー

追記: 2008/12/28 改訂版


プロクシサーバーの作成は難しいです。ブラウザ側とプロクシ間の keep-alive と、プロクシとオリジンサーバー間の keep-alive ができないといけません。その辺の管理がややこしい。というわけで .NET Framework にぜんぶやってもらうことにした。

.NET Framework には System.Net.HttpListener というクラスがあります。このクラスを使うことで簡単に HTTP サーバーをたてることができます。
そしてさらに System.Net.HttpWebRequest, System.Net.HttpWebResponse というクラスがあります。この2つを使えば Http リクエスト、レスポンスを簡単に扱えます。しかも Keep-Alive してくれてそうな感じがします(HttpWebRequest がやってくれているはず。たぶん)。というわけで、プロクシサーバーに必要な、Http サーバの機能と Http クライアントの機能がそろいました。これらを使ってプロクシサーバーを作ってみます。

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Net;
using System.Net.Sockets;
using System.IO;
using System.Diagnostics;

namespace HttpServerApp
{
	class Program
	{
		static void Main(string[] args)
		{
			// Debug の設定
			Debug.Listeners.Add(new ConsoleTraceListener());
			DefaultTraceListener dtl = (DefaultTraceListener)Debug.Listeners["Default"];
			dtl.LogFileName = "debug.txt"; // ファイルへ出力

			HttpServer server = new HttpServer();
			server.Start();
			while (true)
			{
				string str = Console.ReadLine();
				if (str.ToLower().Equals("end"))
					break;
			}
			server.Stop();
		}
	}

	class HttpServer
	{
		HttpListener listener;

		public HttpServer()
		{
		}

		public void Start()
		{
			Debug.WriteLine("Start", "HttpServer");
			if (listener == null)
			{
				listener = new HttpListener();
				listener.Prefixes.Add(string.Format("http://{0}:{1}/", IPAddress.Loopback, 8080));
			}
			listener.Start();

			listener.BeginGetContext(OnGetContext, listener);
		}

		public void Stop()
		{
			listener.Stop();
			Debug.WriteLine("Stop", "HttpServer");
		}

		private void OnGetContext(IAsyncResult ar)
		{
			try
			{
				HttpListenerContext context = ((HttpListener)ar.AsyncState).EndGetContext(ar);
				OnRequest(context);
				listener.BeginGetContext(OnGetContext, listener);
			}
			catch (HttpListenerException e)
			{
				Debug.WriteLine(e);
			}
		}

		private void OnRequest(HttpListenerContext context)
		{
			Debug.WriteLine(context.Request.RawUrl); // 加工されてないアドレス
			Debug.WriteLine(context.Request.UserHostAddress); // どこから接続されたか

			HttpWebRequest webRequest = WebRequest.Create(context.Request.RawUrl) as HttpWebRequest; // そのまま取得
			// ftp:// とかのリクエストにも対応できるはず

			// リクエストラインとヘッダの設定。ある程度しかしていない。
			webRequest.Method = context.Request.HttpMethod;
			webRequest.ProtocolVersion = context.Request.ProtocolVersion;
			webRequest.KeepAlive = context.Request.KeepAlive;
			if (context.Request.UrlReferrer != null)
				webRequest.Referer = context.Request.UrlReferrer.OriginalString;
			webRequest.UserAgent = context.Request.UserAgent;
			webRequest.CookieContainer = new CookieContainer();
			webRequest.CookieContainer.Add(webRequest.RequestUri, context.Request.Cookies);

			byte[] buffer = new byte[8192];
			// ボディあったら送受信
			if (context.Request.HasEntityBody)
			{
				Debug.WriteLine("Request.HasEntityBody", context.Request.RawUrl);
				webRequest.ContentType = context.Request.ContentType;
				if (context.Request.ContentLength64 >= 0)
					webRequest.ContentLength = context.Request.ContentLength64;
				// transfer-encoding: chunked のことは考えていない
				Stream input = context.Request.InputStream;
				Stream output = webRequest.GetRequestStream();

				while (true)
				{
					int read = input.Read(buffer, 0, buffer.Length);
					if (read == 0)
						break;
					output.Write(buffer, 0, read);
				}

				output.Flush();
				input.Close();
				output.Close();
			}

			// レスポンス取得
			HttpWebResponse webResponse = null;
			try
			{
				webResponse = webRequest.GetResponse() as HttpWebResponse;
			}
			catch (WebException e)
			{
				Debug.WriteLine(e);
			}

			// だめだった時の処理。てきとう
			if (webResponse == null)
			{
				context.Response.ProtocolVersion = HttpVersion.Version11;
				context.Response.StatusCode = 503;
				context.Response.StatusDescription = "Response Error";
				context.Response.ContentLength64 = 0;
				context.Response.Close();
				return;
			}

			// ブラウザへ返すレスポンスの設定。あるていど。
			context.Response.ProtocolVersion = webResponse.ProtocolVersion;
			context.Response.StatusCode = (int)webResponse.StatusCode;
			context.Response.StatusDescription = webResponse.StatusDescription;

			context.Response.AddHeader("Server", webResponse.Server);
			context.Response.ContentType = webResponse.ContentType;
			if (webResponse.ContentLength >= 0)
			{
				Debug.WriteLine("Content-Length=" + webResponse.ContentLength, context.Request.RawUrl);
				context.Response.ContentLength64 = webResponse.ContentLength;
			}
			if (webResponse.GetResponseHeader("Transfer-Encoding").ToLower().Equals("chunked"))
			{
				Debug.WriteLine("chunked", context.Request.RawUrl);
				context.Response.SendChunked = true;
			}
			context.Response.KeepAlive = context.Request.KeepAlive;

			// ボディの送受信
			Stream instream = webResponse.GetResponseStream();
			Stream outstream = context.Response.OutputStream;

			try
			{
				// これでうまくいくんだ・・・。ブロッキングせずにちゃんと 0 が返ってくるな。
				// 切断まで 0 は返ってこないからブロッキングするとおもってたけど・・・。
				while (true)
				{
					int read = instream.Read(buffer, 0, buffer.Length);
					Debug.WriteLine(read + " bytes received", context.Request.RawUrl);
					if (read == 0)
						break;
					outstream.Write(buffer, 0, read);
				}
			}
			catch (Exception e)
			{
				Debug.WriteLine(e);
			}

			try
			{
				webResponse.Close();
				context.Response.Close();
			}
			catch (InvalidOperationException e)
			{
				Debug.WriteLine(e);
			}
			catch (HttpListenerException he)
			{
				Debug.WriteLine(he);
			}
		}
	}
}

とこんな感じで動きます。
例外処理とかてきとうです。変なリクエストが送られてきても HttpListener が勝手にエラーレスポンスを返しているようです。
あとはヘッダフィルタなんか付けたりもできそうです。ボディのフィルタリングがめんどくさそう。Content-Length が変わっちゃうとまずいし。html くらいならいったんボディ全部読み込んでも良いか?
まぁ今回はこんな感じで動きましたってとこまでにします。