--- sys/netinet/in_mcast.c.orig +++ sys/netinet/in_mcast.c @@ -2524,6 +2524,7 @@ { struct epoch_tracker et; struct __msfilterreq msfr; + struct sockaddr_storage *kss; sockunion_t *gsa; struct ifnet *ifp; struct in_mfilter *imf; @@ -2536,9 +2537,6 @@ if (error) return (error); - if (msfr.msfr_nsrcs > in_mcast_maxsocksrc) - return (ENOBUFS); - if ((msfr.msfr_fmode != MCAST_EXCLUDE && msfr.msfr_fmode != MCAST_INCLUDE)) return (EINVAL); @@ -2551,13 +2549,24 @@ if (!IN_MULTICAST(ntohl(gsa->sin.sin_addr.s_addr))) return (EINVAL); + if (msfr.msfr_nsrcs > in_mcast_maxsocksrc) + return (ENOBUFS); + kss = mallocarray(msfr.msfr_nsrcs, sizeof(struct sockaddr_storage), + M_TEMP, M_WAITOK); + error = copyin(msfr.msfr_srcs, kss, + sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs); + if (error) + goto out_inp_unlocked; + gsa->sin.sin_port = 0; /* ignore port */ NET_EPOCH_ENTER(et); ifp = ifnet_byindex(msfr.msfr_ifindex); NET_EPOCH_EXIT(et); /* XXXGL: unsafe ifp */ - if (ifp == NULL) - return (EADDRNOTAVAIL); + if (ifp == NULL) { + error = EADDRNOTAVAIL; + goto out_inp_unlocked; + } IN_MULTI_LOCK(); @@ -2589,24 +2598,9 @@ if (msfr.msfr_nsrcs > 0) { struct in_msource *lims; struct sockaddr_in *psin; - struct sockaddr_storage *kss, *pkss; + struct sockaddr_storage *pkss; int i; - INP_WUNLOCK(inp); - - CTR2(KTR_IGMPV3, "%s: loading %lu source list entries", - __func__, (unsigned long)msfr.msfr_nsrcs); - kss = malloc(sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs, - M_TEMP, M_WAITOK); - error = copyin(msfr.msfr_srcs, kss, - sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs); - if (error) { - free(kss, M_TEMP); - return (error); - } - - INP_WLOCK(inp); - /* * Mark all source filters as UNDEFINED at t1. * Restore new group filter mode, as imf_leave() @@ -2641,7 +2635,6 @@ break; lims->imsl_st[1] = imf->imf_st[1]; } - free(kss, M_TEMP); } if (error) @@ -2678,6 +2671,8 @@ out_inp_locked: INP_WUNLOCK(inp); IN_MULTI_UNLOCK(); +out_inp_unlocked: + free(kss, M_TEMP); return (error); } --- sys/netinet6/in6_mcast.c.orig +++ sys/netinet6/in6_mcast.c @@ -2489,6 +2489,7 @@ { struct __msfilterreq msfr; struct epoch_tracker et; + struct sockaddr_storage *kss; sockunion_t *gsa; struct ifnet *ifp; struct in6_mfilter *imf; @@ -2501,9 +2502,6 @@ if (error) return (error); - if (msfr.msfr_nsrcs > in6_mcast_maxsocksrc) - return (ENOBUFS); - if (msfr.msfr_fmode != MCAST_EXCLUDE && msfr.msfr_fmode != MCAST_INCLUDE) return (EINVAL); @@ -2516,19 +2514,31 @@ if (!IN6_IS_ADDR_MULTICAST(&gsa->sin6.sin6_addr)) return (EINVAL); + if (msfr.msfr_nsrcs > in6_mcast_maxsocksrc) + return (ENOBUFS); + kss = mallocarray(msfr.msfr_nsrcs, sizeof(struct sockaddr_storage), + M_TEMP, M_WAITOK); + error = copyin(msfr.msfr_srcs, kss, + sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs); + if (error) + goto out_in6p_unlocked; + gsa->sin6.sin6_port = 0; /* ignore port */ NET_EPOCH_ENTER(et); ifp = ifnet_byindex(msfr.msfr_ifindex); NET_EPOCH_EXIT(et); - if (ifp == NULL) - return (EADDRNOTAVAIL); + if (ifp == NULL) { + error = EADDRNOTAVAIL; + goto out_in6p_unlocked; + } (void)in6_setscope(&gsa->sin6.sin6_addr, ifp, NULL); /* * Take the INP write lock. * Check if this socket is a member of this group. */ + IN6_MULTI_LOCK(); imo = in6p_findmoptions(inp); imf = im6o_match_group(imo, ifp, &gsa->sa); if (imf == NULL) { @@ -2553,24 +2563,9 @@ if (msfr.msfr_nsrcs > 0) { struct in6_msource *lims; struct sockaddr_in6 *psin; - struct sockaddr_storage *kss, *pkss; + struct sockaddr_storage *pkss; int i; - INP_WUNLOCK(inp); - - CTR2(KTR_MLD, "%s: loading %lu source list entries", - __func__, (unsigned long)msfr.msfr_nsrcs); - kss = malloc(sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs, - M_TEMP, M_WAITOK); - error = copyin(msfr.msfr_srcs, kss, - sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs); - if (error) { - free(kss, M_TEMP); - return (error); - } - - INP_WLOCK(inp); - /* * Mark all source filters as UNDEFINED at t1. * Restore new group filter mode, as im6f_leave() @@ -2615,7 +2610,6 @@ break; lims->im6sl_st[1] = imf->im6f_st[1]; } - free(kss, M_TEMP); } if (error) @@ -2650,6 +2644,9 @@ out_in6p_locked: INP_WUNLOCK(inp); + IN6_MULTI_UNLOCK(); +out_in6p_unlocked: + free(kss, M_TEMP); return (error); }