Add safe wrapper for Unix.file_descr

OCaml's `file_descr` type can be used after being closed. In the best
case (the FD hasn't been reused), the kernel will return an error.
However, it may also perform the operation on an unrelated FD, breaking
modularity and causing bugs that are extremely hard to debug.
This commit is contained in:
Thomas Leonard 2021-03-18 12:14:52 +00:00
parent 1d9823f190
commit db4e8d527d
5 changed files with 98 additions and 19 deletions

View File

@ -22,10 +22,32 @@ let () = Sys.(set_signal sigpipe Signal_ignore)
type amount = Exactly of int | Upto of int
module FD = struct
type t = {
mutable fd : [`Open of Unix.file_descr | `Closed]
}
let get op = function
| { fd = `Open fd } -> fd
| { fd = `Closed } -> invalid_arg (op ^ ": file descriptor used after calling close!")
let of_unix fd = { fd = `Open fd }
let to_unix = get "to_unix"
let is_open = function
| { fd = `Open _ } -> true
| { fd = `Closed } -> false
let close t =
let fd = get "close" t in
t.fd <- `Closed;
Unix.close fd
end
type rw_req = {
op: [`R|`W];
file_offset: int option;
fd: Unix.file_descr;
fd: FD.t;
len: amount;
buf: Uring.Region.chunk;
mutable cur_off: int;
@ -54,6 +76,7 @@ type t = {
}
let rec submit_rw_req st ({op; file_offset; fd; buf; len; cur_off; _} as req) =
let fd = FD.get "submit_rw_req" fd in
let {uring;io_q;_} = st in
let off = Uring.Region.to_offset buf + cur_off in
let len = match len with Exactly l | Upto l -> l in
@ -76,7 +99,7 @@ let enqueue_read st action (file_offset,fd,buf,len) =
let rec enqueue_poll_add st action fd poll_mask =
Logs.debug (fun l -> l "poll_add: submitting call");
let subm = Uring.poll_add st.uring fd poll_mask (Poll_add action) in
let subm = Uring.poll_add st.uring (FD.get "poll_add" fd) poll_mask (Poll_add action) in
if not subm then (* wait until an sqe is available *)
Queue.push (fun st -> enqueue_poll_add st action fd poll_mask) st.io_q
@ -178,7 +201,7 @@ effect Yield : unit
let yield () =
perform Yield
effect ERead : (int option * Unix.file_descr * Uring.Region.chunk * amount) -> int
effect ERead : (int option * FD.t * Uring.Region.chunk * amount) -> int
let read_exactly ?file_offset fd buf len =
let res = perform (ERead (file_offset, fd, buf, Exactly len)) in
@ -194,7 +217,7 @@ let read_upto ?file_offset fd buf len =
else
res
effect EPoll_add : Unix.file_descr * Uring.Poll_mask.t -> int
effect EPoll_add : FD.t * Uring.Poll_mask.t -> int
let await_readable fd =
let res = perform (EPoll_add (fd, Uring.Poll_mask.(pollin + pollerr))) in
@ -208,7 +231,7 @@ let await_writable fd =
if res < 0 then
raise (Failure (Fmt.strf "await_writable %d" res)) (* FIXME Unix_error *)
effect EWrite : (int option * Unix.file_descr * Uring.Region.chunk * amount) -> int
effect EWrite : (int option * FD.t * Uring.Region.chunk * amount) -> int
let write ?file_offset fd buf len =
let res = perform (EWrite (file_offset, fd, buf, Exactly len)) in
@ -222,6 +245,20 @@ let alloc () = perform Alloc
effect Free : Uring.Region.chunk -> unit
let free buf = perform (Free buf)
let openfile path flags mode =
FD.of_unix (Unix.openfile path flags mode)
let fstat fd =
Unix.fstat (FD.get "fstat" fd)
let shutdown socket command =
Unix.shutdown (FD.get "shutdown" socket) command
let accept socket =
await_readable socket;
let conn, addr = Unix.accept ~cloexec:true (FD.get "accept" socket) in
FD.of_unix conn, addr
let run ?(queue_depth=64) ?(block_size=4096) main =
Log.debug (fun l -> l "starting run");
(* TODO unify this allocation API around baregion/uring *)

View File

@ -16,14 +16,38 @@
type t
(** Wrap [Unix.file_descr] to track whether it has been closed. *)
module FD : sig
type t
val is_open : t -> bool
(** [is_open t] is [true] if {!close t} hasn't been called yet. *)
val close : t -> unit
(** [close t] closes [t].
@raise Invalid_arg if [t] is already closed. *)
val of_unix : Unix.file_descr -> t
(** [of_unix fd] wraps [fd] as an open file descriptor.
This is unsafe if [fd] is closed directly (before or after wrapping it). *)
val to_unix : t -> Unix.file_descr
(** [to_unix t] returns the wrapped descriptor.
This allows unsafe access to the FD.
@raise Invalid_arg if [t] is closed. *)
end
(** {1 Fibre functions} *)
val fork : (unit -> 'a) -> 'a Promise.t
(** [fork fn] starts running [fn ()] and returns a promise for its result. *)
val yield : unit -> unit
(** [yield ()] asks the scheduler to switch to the next runnable task.
The current task remains runnable, but goes to the back of the queue. *)
val sleep : float -> unit
(** [sleep s] blocks until (at least) [s] seconds have passed. *)
(** {1 Memory allocation functions} *)
@ -33,26 +57,45 @@ val free : Uring.Region.chunk -> unit
(** {1 File manipulation functions} *)
val read_upto : ?file_offset:int -> Unix.file_descr -> Uring.Region.chunk -> int -> int
val openfile : string -> Unix.open_flag list -> int -> FD.t
(** Like {!Unix.open_file}. *)
val read_upto : ?file_offset:int -> FD.t -> Uring.Region.chunk -> int -> int
(** [read_upto fd chunk len] reads at most [len] bytes from [fd],
returning as soon as some data is available.
@param file_offset Read from the given position in [fd] (default: 0).
@raise End_of_file Raised if all data has already been read. *)
val read_exactly : ?file_offset:int -> Unix.file_descr -> Uring.Region.chunk -> int -> unit
val read_exactly : ?file_offset:int -> FD.t -> Uring.Region.chunk -> int -> unit
(** [read_exactly fd chunk len] reads exactly [len] bytes from [fd],
performing multiple read operations if necessary.
@param file_offset Read from the given position in [fd] (default: 0).
@raise End_of_file Raised if the stream ends before [len] bytes have been read. *)
val write : ?file_offset:int -> Unix.file_descr -> Uring.Region.chunk -> int -> unit
val write : ?file_offset:int -> FD.t -> Uring.Region.chunk -> int -> unit
(** [write fd buf len] writes exactly [len] bytes from [buf] to [fd].
It blocks until the OS confirms the write is done,
and resubmits automatically if the OS doesn't write all of it at once. *)
val await_readable : Unix.file_descr -> unit
val await_readable : FD.t -> unit
(** [await_readable fd] blocks until [fd] is readable (or has an error). *)
val await_writable : Unix.file_descr -> unit
val await_writable : FD.t -> unit
(** [await_writable fd] blocks until [fd] is writable (or has an error). *)
val fstat : FD.t -> Unix.stats
(** Like {!Unix.fstat}. *)
(** {1 Sockets} *)
val accept : FD.t -> (FD.t * Unix.sockaddr)
(** [accept t] blocks until a new connection is received on listening socket [t].
It returns the new connection and the address of the connecting peer.
The new connection has the close-on-exec flag set automatically. *)
val shutdown : FD.t -> Unix.shutdown_command -> unit
(** Like {!Unix.shutdown}. *)
(** {1 Main Loop} *)
val run : ?queue_depth:int -> ?block_size:int -> (unit -> unit) -> unit

View File

@ -9,8 +9,7 @@ let setup_log level =
let () =
setup_log (Some Logs.Debug);
(* TODO expose openfile from euring *)
let fd = Unix.(handle_unix_error (openfile "test.txt" [O_RDONLY]) 0) in
let fd = Unix.handle_unix_error (Eunix.openfile "test.txt" Unix.[O_RDONLY]) 0 in
run (fun () ->
let buf = alloc () in
let _ = read_exactly fd buf 5 in

View File

@ -24,12 +24,12 @@ let copy_file infd outfd insize block_size =
let run_cp block_size queue_depth infile outfile () =
let open Unix in
let infd = openfile infile [O_RDONLY] 0 in
let outfd = openfile outfile [O_WRONLY; O_CREAT; O_TRUNC] 0o644 in
let insize = fstat infd |> fun {st_size; _} -> st_size in
let infd = Eunix.openfile infile [O_RDONLY] 0 in
let outfd = Eunix.openfile outfile [O_WRONLY; O_CREAT; O_TRUNC] 0o644 in
let insize = Eunix.fstat infd |> fun {st_size; _} -> st_size in
Logs.debug (fun l -> l "eurcp: %s -> %s size %d queue %d bs %d"
infile outfile insize queue_depth block_size);
U.run ~queue_depth ~block_size (fun () -> copy_file infd outfd insize block_size);
Logs.debug (fun l -> l "eurcp: done");
close outfd;
close infd
Eunix.FD.close outfd;
Eunix.FD.close infd

View File

@ -49,7 +49,7 @@ let test_promise_exn () =
let read_one_byte r =
Eunix.fork (fun () ->
Eunix.await_readable r;
Eunix.await_readable (Eunix.FD.of_unix r);
let b = Bytes.create 1 in
let got = Unix.read r b 0 1 in
assert (got = 1);
@ -61,7 +61,7 @@ let test_poll_add () =
let r, w = Unix.pipe () in
let thread = read_one_byte r in
Eunix.yield ();
Eunix.await_writable w;
Eunix.await_writable (Eunix.FD.of_unix w);
let sent = Unix.write w (Bytes.of_string "!") 0 1 in
assert (sent = 1);
let result = Promise.await thread in