diff --git a/lib_eioio/eunix.ml b/lib_eioio/eunix.ml index a8872de..45e252d 100644 --- a/lib_eioio/eunix.ml +++ b/lib_eioio/eunix.ml @@ -22,10 +22,32 @@ let () = Sys.(set_signal sigpipe Signal_ignore) type amount = Exactly of int | Upto of int +module FD = struct + type t = { + mutable fd : [`Open of Unix.file_descr | `Closed] + } + + let get op = function + | { fd = `Open fd } -> fd + | { fd = `Closed } -> invalid_arg (op ^ ": file descriptor used after calling close!") + + let of_unix fd = { fd = `Open fd } + let to_unix = get "to_unix" + + let is_open = function + | { fd = `Open _ } -> true + | { fd = `Closed } -> false + + let close t = + let fd = get "close" t in + t.fd <- `Closed; + Unix.close fd +end + type rw_req = { op: [`R|`W]; file_offset: int option; - fd: Unix.file_descr; + fd: FD.t; len: amount; buf: Uring.Region.chunk; mutable cur_off: int; @@ -54,6 +76,7 @@ type t = { } let rec submit_rw_req st ({op; file_offset; fd; buf; len; cur_off; _} as req) = + let fd = FD.get "submit_rw_req" fd in let {uring;io_q;_} = st in let off = Uring.Region.to_offset buf + cur_off in let len = match len with Exactly l | Upto l -> l in @@ -76,7 +99,7 @@ let enqueue_read st action (file_offset,fd,buf,len) = let rec enqueue_poll_add st action fd poll_mask = Logs.debug (fun l -> l "poll_add: submitting call"); - let subm = Uring.poll_add st.uring fd poll_mask (Poll_add action) in + let subm = Uring.poll_add st.uring (FD.get "poll_add" fd) poll_mask (Poll_add action) in if not subm then (* wait until an sqe is available *) Queue.push (fun st -> enqueue_poll_add st action fd poll_mask) st.io_q @@ -178,7 +201,7 @@ effect Yield : unit let yield () = perform Yield -effect ERead : (int option * Unix.file_descr * Uring.Region.chunk * amount) -> int +effect ERead : (int option * FD.t * Uring.Region.chunk * amount) -> int let read_exactly ?file_offset fd buf len = let res = perform (ERead (file_offset, fd, buf, Exactly len)) in @@ -194,7 +217,7 @@ let read_upto ?file_offset fd buf len = else res -effect EPoll_add : Unix.file_descr * Uring.Poll_mask.t -> int +effect EPoll_add : FD.t * Uring.Poll_mask.t -> int let await_readable fd = let res = perform (EPoll_add (fd, Uring.Poll_mask.(pollin + pollerr))) in @@ -208,7 +231,7 @@ let await_writable fd = if res < 0 then raise (Failure (Fmt.strf "await_writable %d" res)) (* FIXME Unix_error *) -effect EWrite : (int option * Unix.file_descr * Uring.Region.chunk * amount) -> int +effect EWrite : (int option * FD.t * Uring.Region.chunk * amount) -> int let write ?file_offset fd buf len = let res = perform (EWrite (file_offset, fd, buf, Exactly len)) in @@ -222,6 +245,20 @@ let alloc () = perform Alloc effect Free : Uring.Region.chunk -> unit let free buf = perform (Free buf) +let openfile path flags mode = + FD.of_unix (Unix.openfile path flags mode) + +let fstat fd = + Unix.fstat (FD.get "fstat" fd) + +let shutdown socket command = + Unix.shutdown (FD.get "shutdown" socket) command + +let accept socket = + await_readable socket; + let conn, addr = Unix.accept ~cloexec:true (FD.get "accept" socket) in + FD.of_unix conn, addr + let run ?(queue_depth=64) ?(block_size=4096) main = Log.debug (fun l -> l "starting run"); (* TODO unify this allocation API around baregion/uring *) diff --git a/lib_eioio/eunix.mli b/lib_eioio/eunix.mli index fb0fd57..66a406b 100644 --- a/lib_eioio/eunix.mli +++ b/lib_eioio/eunix.mli @@ -16,14 +16,38 @@ type t +(** Wrap [Unix.file_descr] to track whether it has been closed. *) +module FD : sig + type t + + val is_open : t -> bool + (** [is_open t] is [true] if {!close t} hasn't been called yet. *) + + val close : t -> unit + (** [close t] closes [t]. + @raise Invalid_arg if [t] is already closed. *) + + val of_unix : Unix.file_descr -> t + (** [of_unix fd] wraps [fd] as an open file descriptor. + This is unsafe if [fd] is closed directly (before or after wrapping it). *) + + val to_unix : t -> Unix.file_descr + (** [to_unix t] returns the wrapped descriptor. + This allows unsafe access to the FD. + @raise Invalid_arg if [t] is closed. *) +end + (** {1 Fibre functions} *) val fork : (unit -> 'a) -> 'a Promise.t (** [fork fn] starts running [fn ()] and returns a promise for its result. *) val yield : unit -> unit +(** [yield ()] asks the scheduler to switch to the next runnable task. + The current task remains runnable, but goes to the back of the queue. *) val sleep : float -> unit +(** [sleep s] blocks until (at least) [s] seconds have passed. *) (** {1 Memory allocation functions} *) @@ -33,26 +57,45 @@ val free : Uring.Region.chunk -> unit (** {1 File manipulation functions} *) -val read_upto : ?file_offset:int -> Unix.file_descr -> Uring.Region.chunk -> int -> int +val openfile : string -> Unix.open_flag list -> int -> FD.t +(** Like {!Unix.open_file}. *) + +val read_upto : ?file_offset:int -> FD.t -> Uring.Region.chunk -> int -> int (** [read_upto fd chunk len] reads at most [len] bytes from [fd], returning as soon as some data is available. @param file_offset Read from the given position in [fd] (default: 0). @raise End_of_file Raised if all data has already been read. *) -val read_exactly : ?file_offset:int -> Unix.file_descr -> Uring.Region.chunk -> int -> unit +val read_exactly : ?file_offset:int -> FD.t -> Uring.Region.chunk -> int -> unit (** [read_exactly fd chunk len] reads exactly [len] bytes from [fd], performing multiple read operations if necessary. @param file_offset Read from the given position in [fd] (default: 0). @raise End_of_file Raised if the stream ends before [len] bytes have been read. *) -val write : ?file_offset:int -> Unix.file_descr -> Uring.Region.chunk -> int -> unit +val write : ?file_offset:int -> FD.t -> Uring.Region.chunk -> int -> unit +(** [write fd buf len] writes exactly [len] bytes from [buf] to [fd]. + It blocks until the OS confirms the write is done, + and resubmits automatically if the OS doesn't write all of it at once. *) -val await_readable : Unix.file_descr -> unit +val await_readable : FD.t -> unit (** [await_readable fd] blocks until [fd] is readable (or has an error). *) -val await_writable : Unix.file_descr -> unit +val await_writable : FD.t -> unit (** [await_writable fd] blocks until [fd] is writable (or has an error). *) +val fstat : FD.t -> Unix.stats +(** Like {!Unix.fstat}. *) + +(** {1 Sockets} *) + +val accept : FD.t -> (FD.t * Unix.sockaddr) +(** [accept t] blocks until a new connection is received on listening socket [t]. + It returns the new connection and the address of the connecting peer. + The new connection has the close-on-exec flag set automatically. *) + +val shutdown : FD.t -> Unix.shutdown_command -> unit +(** Like {!Unix.shutdown}. *) + (** {1 Main Loop} *) val run : ?queue_depth:int -> ?block_size:int -> (unit -> unit) -> unit diff --git a/tests/basic_eunix.ml b/tests/basic_eunix.ml index 2e19d4c..2abada8 100644 --- a/tests/basic_eunix.ml +++ b/tests/basic_eunix.ml @@ -9,8 +9,7 @@ let setup_log level = let () = setup_log (Some Logs.Debug); - (* TODO expose openfile from euring *) - let fd = Unix.(handle_unix_error (openfile "test.txt" [O_RDONLY]) 0) in + let fd = Unix.handle_unix_error (Eunix.openfile "test.txt" Unix.[O_RDONLY]) 0 in run (fun () -> let buf = alloc () in let _ = read_exactly fd buf 5 in diff --git a/tests/eurcp_lib.ml b/tests/eurcp_lib.ml index 6e11ad1..c5c7fad 100644 --- a/tests/eurcp_lib.ml +++ b/tests/eurcp_lib.ml @@ -24,12 +24,12 @@ let copy_file infd outfd insize block_size = let run_cp block_size queue_depth infile outfile () = let open Unix in - let infd = openfile infile [O_RDONLY] 0 in - let outfd = openfile outfile [O_WRONLY; O_CREAT; O_TRUNC] 0o644 in - let insize = fstat infd |> fun {st_size; _} -> st_size in + let infd = Eunix.openfile infile [O_RDONLY] 0 in + let outfd = Eunix.openfile outfile [O_WRONLY; O_CREAT; O_TRUNC] 0o644 in + let insize = Eunix.fstat infd |> fun {st_size; _} -> st_size in Logs.debug (fun l -> l "eurcp: %s -> %s size %d queue %d bs %d" infile outfile insize queue_depth block_size); U.run ~queue_depth ~block_size (fun () -> copy_file infd outfd insize block_size); Logs.debug (fun l -> l "eurcp: done"); - close outfd; - close infd + Eunix.FD.close outfd; + Eunix.FD.close infd diff --git a/tests/test.ml b/tests/test.ml index bc0e6a0..1a0b399 100644 --- a/tests/test.ml +++ b/tests/test.ml @@ -49,7 +49,7 @@ let test_promise_exn () = let read_one_byte r = Eunix.fork (fun () -> - Eunix.await_readable r; + Eunix.await_readable (Eunix.FD.of_unix r); let b = Bytes.create 1 in let got = Unix.read r b 0 1 in assert (got = 1); @@ -61,7 +61,7 @@ let test_poll_add () = let r, w = Unix.pipe () in let thread = read_one_byte r in Eunix.yield (); - Eunix.await_writable w; + Eunix.await_writable (Eunix.FD.of_unix w); let sent = Unix.write w (Bytes.of_string "!") 0 1 in assert (sent = 1); let result = Promise.await thread in